/*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "attention_dtypes.h"
#include "attention_utils.cuh"
#ifdef ENABLE_FP8_E5M2
#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
#endif

#include <algorithm>
#include <cassert>
#include <cfloat>
#include <cstdint>

// WARP_SIZE: number of threads per warp (wavefront on ROCm). On CUDA it is the
// literal 32; on ROCm it maps to HIP's `warpSize`. The kernels below use it in
// constexpr expressions, so it must be a constant expression on both platforms.
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif

// Small arithmetic helpers that stay usable in constant expressions
// (template-dependent constexpr sizes). As function-like macros they evaluate
// their arguments more than once, so never pass expressions with side effects.
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
// Integer ceiling division; intended for non-negative `a` and positive `b`.
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

namespace vllm {

// Utility function for attention softmax: block-wide float sum.
//
// Every thread of the block must call this with its partial `sum`; every thread
// gets the block total back. Two stages:
//   1. an XOR-shuffle butterfly inside each warp, after which every lane holds
//      its warp's sum;
//   2. warp leaders publish their sums to red_smem[warp], and after a barrier
//      each warp reloads the NUM_WARPS partials into lanes [0, NUM_WARPS) and
//      reduces them with a second butterfly; lane 0 is then broadcast.
//
// red_smem: scratch with room for at least NUM_WARPS floats (callers pass a
//           slice of a __shared__ array).
// NUM_WARPS: warps in the block; presumably blockDim.x / WARP_SIZE — callers
//            derive it from their NUM_THREADS template parameter.
//
// There is no barrier after red_smem is read, so a caller must not overwrite
// the same red_smem slots before all threads have left this function.
template<int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
  // Lanes >= NUM_WARPS keep their own warp's sum during stage 2; the XOR
  // butterfly only keeps lanes [0, NUM_WARPS) isolated from them when
  // NUM_WARPS is a power of two.
  static_assert(NUM_WARPS > 0 && (NUM_WARPS & (NUM_WARPS - 1)) == 0,
                "block_sum requires a positive power-of-two NUM_WARPS");

  // Decompose the thread index into warp / lane.
  int warp = threadIdx.x / WARP_SIZE;
  int lane = threadIdx.x % WARP_SIZE;

  // Compute the sum per warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    sum += VLLM_SHFL_XOR_SYNC(sum, mask);
  }

  // Warp leaders store the data to shared memory.
  if (lane == 0) {
    red_smem[warp] = sum;
  }

  // Make sure the data is in shared memory.
  __syncthreads();

  // The warps compute the final sums.
  if (lane < NUM_WARPS) {
    sum = red_smem[lane];
  }

  // Parallel reduction inside the warp.
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    sum += VLLM_SHFL_XOR_SYNC(sum, mask);
  }

  // Broadcast to other threads.
  return VLLM_SHFL_SYNC(sum, 0);
}

// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
//
// Core single-query paged attention routine shared by the v1 and v2 kernels.
//
// Each thread block handles one query head of one sequence:
//   head_idx = blockIdx.x (num_heads = gridDim.x), seq_idx = blockIdx.y and
//   partition_idx = blockIdx.z (max_num_partitions = gridDim.z).
// With PARTITION_SIZE == 0 the block covers the whole context of the sequence.
// Otherwise it covers tokens [partition_idx * PARTITION_SIZE,
// (partition_idx + 1) * PARTITION_SIZE) clipped to context_len, and thread 0 also
// records the partition's max logit and exp sum in max_logits / exp_sums.
// Partitions that start at or past context_len return without writing anything.
//
// Steps:
//   1. q . k (times `scale`, plus the ALiBi bias when slopes are given) for every
//      token in range, stored as FP32 logits in dynamic shared memory;
//   2. block-wide max and exp sum, then softmax over the range;
//   3. softmax-weighted sum of V rows, reduced across warps and written to
//      out[seq_idx, head_idx, partition_idx, :].
//
// Launch requirements, as implied by the code below:
//   * The block is assumed to have NUM_THREADS threads (loops stride by
//     NUM_THREADS and warps are counted from it, not from blockDim.x).
//   * Dynamic shared memory is reused: it must hold, as floats, both the logits of
//     this block's token range rounded up to whole KV blocks and
//     (NUM_WARPS / 2) * HEAD_SIZE values for the cross-warp output reduction.
//   * num_heads is assumed to be a multiple of num_kv_heads
//     (kv_head_idx = head_idx / (num_heads / num_kv_heads)).
template<
  typename scalar_t,
  typename cache_t,
  int HEAD_SIZE,
  int BLOCK_SIZE,
  int NUM_THREADS,
  bool IS_FP8_E5M2_KV_CACHE,
  int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel(
  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
  scalar_t* __restrict__ out,             // [num_seqs, num_heads, max_num_partitions, head_size]
  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
  const int num_kv_heads,                 // scalar: number of KV heads
  const float scale,
  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
  const int* __restrict__ context_lens,   // [num_seqs]
  const int max_num_blocks_per_seq,
  const float* __restrict__ alibi_slopes, // [num_heads]
  const int q_stride,
  const int kv_block_stride,
  const int kv_head_stride) {
  // A partition must consist of whole KV blocks: the block range below uses
  // PARTITION_SIZE / BLOCK_SIZE while the early exit uses PARTITION_SIZE tokens.
  // (Trivially true for PARTITION_SIZE == 0.)
  static_assert(PARTITION_SIZE % BLOCK_SIZE == 0,
                "PARTITION_SIZE must be a multiple of BLOCK_SIZE");

  const int seq_idx = blockIdx.y;
  const int partition_idx = blockIdx.z;
  const int max_num_partitions = gridDim.z;
  constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
  const int context_len = context_lens[seq_idx];
  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
    // No work to do. Terminate the thread block.
    return;
  }

  const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
  const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;

  // [start_block_idx, end_block_idx) is the range of blocks to process.
  const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
  const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
  const int num_blocks = end_block_idx - start_block_idx;

  // [start_token_idx, end_token_idx) is the range of tokens to process.
  const int start_token_idx = start_block_idx * BLOCK_SIZE;
  const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
  const int num_tokens = end_token_idx - start_token_idx;

  constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  // Thread groups must tile the block exactly. This is a compile-time fact, so it
  // is checked statically instead of with a per-thread device assert.
  static_assert(NUM_THREADS % THREAD_GROUP_SIZE == 0,
                "THREAD_GROUP_SIZE must divide NUM_THREADS");
  constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE;
  constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  // The XOR-shuffle max reduction across warps and the halving cross-warp output
  // reduction below are only correct for a power-of-two warp count.
  static_assert(NUM_WARPS > 0 && (NUM_WARPS & (NUM_WARPS - 1)) == 0,
                "NUM_THREADS / WARP_SIZE must be a positive power of two");
  const int thread_idx = threadIdx.x;
  const int warp_idx = thread_idx / WARP_SIZE;
  const int lane = thread_idx % WARP_SIZE;

  const int head_idx = blockIdx.x;
  const int num_heads = gridDim.x;
  const int num_queries_per_kv = num_heads / num_kv_heads;
  const int kv_head_idx = head_idx / num_queries_per_kv;
  const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];

  // A vector type to store a part of a key or a query.
  // The vector size is configured in such a way that the threads in a thread group
  // fetch or compute 16 bytes at a time.
  // For example, if the size of a thread group is 4 and the data type is half,
  // then the vector size is 16 / (4 * sizeof(half)) == 2.
  constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
  using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
  using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
#ifdef ENABLE_FP8_E5M2
  using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
#endif

  constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
  constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;

  const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
  const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;

  // Load the query into shared memory (q_vecs).
  // Each thread in a thread group uses a different part of the query.
  // For example, if the thread group size is 4, then the first thread in the group
  // uses the 0, 4, 8, ... th vectors of the query, the second thread uses the
  // 1, 5, 9, ... th vectors of the query, and so on.
  // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
  __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll
  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
    q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
  }
  __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs

  // Memory planning.
  extern __shared__ char shared_mem[];
  // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
  float* logits = reinterpret_cast<float*>(shared_mem);
  // Workspace for reduction: [0, NUM_WARPS) for the max, [NUM_WARPS, 2 * NUM_WARPS)
  // for block_sum.
  __shared__ float red_smem[2 * NUM_WARPS];

  // x is the number of cache_t elements in 16 bytes and is the innermost
  // dimension of the key-cache layout. In the unquantized case (cache_t ==
  // scalar_t) it usually equals THREAD_GROUP_SIZE * VEC_SIZE, i.e. a thread group
  // fetches x elements of a key at a time; with an FP8 cache it is larger.
  constexpr int x = 16 / sizeof(cache_t);
  float qk_max = -FLT_MAX;

  // Iterate over the key blocks.
  // Each warp fetches a block of keys for each iteration.
  // Each thread group in a warp fetches a key from the block, and computes
  // dot product with the query.
  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
    // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
    // because int32 can lead to overflow when this variable is multiplied by large numbers
    // (e.g., kv_block_stride).
    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);

    // Load a key to registers.
    // Each thread in a thread group has a different part of the key.
    // For example, if the thread group size is 4, then the first thread in the group
    // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
    // vectors of the key, and so on.
    for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
      const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
      const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
      K_vec k_vecs[NUM_VECS_PER_THREAD];

#pragma unroll
      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
        const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
                                       + kv_head_idx * kv_head_stride
                                       + physical_block_offset * x;
        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
        const int offset1 = (vec_idx * VEC_SIZE) / x;
        const int offset2 = (vec_idx * VEC_SIZE) % x;
        if constexpr (IS_FP8_E5M2_KV_CACHE) {
#ifdef ENABLE_FP8_E5M2
          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
          // Vector conversion from Quant_vec to K_vec.
          k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
#else
          assert(false);
#endif
        } else {
          k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
        }
      }

      // Compute dot product.
      // This includes a reduction across the threads in the same thread group.
      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
      // Add the ALiBi bias if slopes are given.
      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;

      if (thread_group_offset == 0) {
        // Store the partial reductions to shared memory.
        // NOTE(woosuk): It is required to zero out the masked logits.
        const bool mask = token_idx >= context_len;
        logits[token_idx - start_token_idx] = mask ? 0.f : qk;
        // Update the max value.
        qk_max = mask ? qk_max : fmaxf(qk_max, qk);
      }
    }
  }

  // Perform reduction across the threads in the same warp to get the
  // max qk value for each "warp" (not across the thread block yet).
  // The 0-th thread of each thread group already has its max qk value.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = qk_max;
  }
  __syncthreads();

  // TODO(woosuk): Refactor this part.
  // Get the max qk value for the sequence.
  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
  }
  // Broadcast the max qk value to all threads.
  qk_max = VLLM_SHFL_SYNC(qk_max, 0);

  // Get the sum of the exp values.
  float exp_sum = 0.f;
  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    float val = __expf(logits[i] - qk_max);
    logits[i] = val;
    exp_sum += val;
  }
  exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);

  // Compute softmax.
  const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    logits[i] *= inv_sum;
  }
  __syncthreads();

  // If partitioning is enabled, store the max logit and exp_sum.
  if (USE_PARTITIONING && thread_idx == 0) {
    float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
                                       + head_idx * max_num_partitions
                                       + partition_idx;
    *max_logits_ptr = qk_max;
    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
                                   + head_idx * max_num_partitions
                                   + partition_idx;
    *exp_sums_ptr = exp_sum;
  }

  // Each thread fetches V_VEC_SIZE elements (up to 16 bytes of scalar_t) from the
  // value cache at a time.
  constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
  using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
  using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
#ifdef ENABLE_FP8_E5M2
  using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
#endif
  using Float_L_vec = typename FloatVec<L_vec>::Type;

  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
  constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
  constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);

  // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
  float accs[NUM_ROWS_PER_THREAD];
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    accs[i] = 0.f;
  }

  scalar_t zero_value;
  zero(zero_value);
  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
    // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
    // because int32 can lead to overflow when this variable is multiplied by large numbers
    // (e.g., kv_block_stride).
    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
    L_vec logits_vec;
    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));

    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
                                   + kv_head_idx * kv_head_stride;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE) {
        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
        V_vec v_vec;
        if constexpr (IS_FP8_E5M2_KV_CACHE) {
#ifdef ENABLE_FP8_E5M2
          V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
          // Vector conversion from V_quant_vec to V_vec.
          v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
#else
          assert(false);
#endif
        } else {
          v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
        }
        if (block_idx == num_context_blocks - 1) {
          // NOTE(woosuk): When v_vec contains the tokens that are out of the context,
          // we should explicitly zero out the values since they may contain NaNs.
          // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
          scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
          for (int j = 0; j < V_VEC_SIZE; j++) {
            v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
          }
        }
        accs[i] += dot(logits_vec, v_vec);
      }
    }
  }

  // Perform reduction within each warp.
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    float acc = accs[i];
#pragma unroll
    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
      acc += VLLM_SHFL_XOR_SYNC(acc, mask);
    }
    accs[i] = acc;
  }

  // NOTE(woosuk): A barrier is required because the shared memory space for logits
  // is reused for the output.
  __syncthreads();

  // Perform reduction across warps: each round the upper half of the still-active
  // warps hands its partial rows to the lower half through shared memory.
  float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
  for (int active_warps = NUM_WARPS; active_warps > 1; active_warps /= 2) {
    int mid = active_warps / 2;
    // Upper warps write to shared memory.
    if (warp_idx >= mid && warp_idx < active_warps) {
      float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          dst[row_idx] = accs[i];
        }
      }
    }
    __syncthreads();

    // Lower warps update the output.
    if (warp_idx < mid) {
      const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll
      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          accs[i] += src[row_idx];
        }
      }
    }
    __syncthreads();
  }

  // Write the final output.
  if (warp_idx == 0) {
    scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
                            + head_idx * max_num_partitions * HEAD_SIZE
                            + partition_idx * HEAD_SIZE;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
        from_float(*(out_ptr + row_idx), accs[i]);
      }
    }
  }
}

// Grid: (num_heads, num_seqs, 1).
//
// Single-pass (unpartitioned) paged attention entry point. Forwards to
// paged_attention_kernel with PARTITION_SIZE left at its default of 0, so each
// thread block processes the whole context of one (head, sequence) pair and
// writes its result directly into `out`. exp_sums / max_logits are passed as
// nullptr because paged_attention_kernel only writes them when partitioning is
// enabled.
template<
  typename scalar_t,
  typename cache_t,
  int HEAD_SIZE,
  int BLOCK_SIZE,
  int NUM_THREADS,
  bool IS_FP8_E5M2_KV_CACHE>
__global__ void paged_attention_v1_kernel(
  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
  const int num_kv_heads,                 // scalar: number of KV heads
  const float scale,
  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
  const int* __restrict__ context_lens,   // [num_seqs]
  const int max_num_blocks_per_seq,
  const float* __restrict__ alibi_slopes, // [num_heads]
  const int q_stride,
  const int kv_block_stride,
  const int kv_head_stride) {
  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE>(
    /* exp_sums */ nullptr, /* max_logits */ nullptr,
    out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
    max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
}

// Grid: (num_heads, num_seqs, max_num_partitions).
//
// First pass of partitioned paged attention. Each thread block handles one
// PARTITION_SIZE-token slice of one (head, sequence) pair. It writes the slice's
// softmax-weighted output to tmp_out and the slice's max logit / exp sum to
// max_logits / exp_sums. Partitions that start at or past the context length
// exit without writing. paged_attention_v2_reduce_kernel merges the partitions.
template<
  typename scalar_t,
  typename cache_t,
  int HEAD_SIZE,
  int BLOCK_SIZE,
  int NUM_THREADS,
  bool IS_FP8_E5M2_KV_CACHE,
  int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel(
  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
  scalar_t* __restrict__ tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
  const int num_kv_heads,                 // scalar: number of KV heads
  const float scale,
  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
  const int* __restrict__ context_lens,   // [num_seqs]
  const int max_num_blocks_per_seq,
  const float* __restrict__ alibi_slopes, // [num_heads]
  const int q_stride,
  const int kv_block_stride,
  const int kv_head_stride) {
  // PARTITION_SIZE == 0 would disable partitioning in paged_attention_kernel, so
  // exp_sums / max_logits would never be written. A partition that is not a whole
  // number of KV blocks would make the token and block ranges disagree.
  static_assert(PARTITION_SIZE > 0 && PARTITION_SIZE % BLOCK_SIZE == 0,
                "PARTITION_SIZE must be a positive multiple of BLOCK_SIZE");
  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>(
    exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
    block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
    q_stride, kv_block_stride, kv_head_stride);
}

// Grid: (num_heads, num_seqs).
//
// Second pass of partitioned paged attention: merges per-partition results into
// the final output for one (head, sequence) pair per thread block.
//
// For partition p, let
//   * m_p = max_logits[p],
//   * s_p = exp_sums[p], the partition's exp sum relative to m_p,
//   * o_p = tmp_out[p], the partition's own softmax-normalized output.
// With M = max_p m_p, the kernel computes
//   out = sum_p o_p * s_p * exp(m_p - M) / (sum_q s_q * exp(m_q - M) + 1e-6).
// When the context fits in one partition, tmp_out[0] is copied through unchanged.
//
// Dynamic shared memory: 2 * num_partitions floats. The first half holds the
// max logits; the second half holds the rescaled exp sums.
// The block is assumed to have NUM_THREADS threads (NUM_WARPS and the final loop
// stride are derived from it).
template<
  typename scalar_t,
  int HEAD_SIZE,
  int NUM_THREADS,
  int PARTITION_SIZE>
__global__ void paged_attention_v2_reduce_kernel(
  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
  const float* __restrict__ exp_sums,     // [num_seqs, num_heads, max_num_partitions]
  const float* __restrict__ max_logits,   // [num_seqs, num_heads, max_num_partitions]
  const scalar_t* __restrict__ tmp_out,   // [num_seqs, num_heads, max_num_partitions, head_size]
  const int* __restrict__ context_lens,   // [num_seqs]
  const int max_num_partitions) {
  // PARTITION_SIZE is used as a divisor below.
  static_assert(PARTITION_SIZE > 0, "PARTITION_SIZE must be positive");
  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  // The cross-warp XOR-shuffle max reduction requires a power-of-two warp count.
  static_assert(NUM_WARPS > 0 && (NUM_WARPS & (NUM_WARPS - 1)) == 0,
                "NUM_THREADS / WARP_SIZE must be a positive power of two");

  const int num_heads = gridDim.x;
  const int head_idx = blockIdx.x;
  const int seq_idx = blockIdx.y;
  const int context_len = context_lens[seq_idx];
  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
  if (num_partitions == 1) {
    // No need to reduce. Only copy tmp_out to out.
    scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
    const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
                                          + head_idx * max_num_partitions * HEAD_SIZE;
    for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
      out_ptr[i] = tmp_out_ptr[i];
    }
    // Terminate the thread block.
    return;
  }

  const int warp_idx = threadIdx.x / WARP_SIZE;
  const int lane = threadIdx.x % WARP_SIZE;

  // Size: 2 * num_partitions floats.
  extern __shared__ char shared_mem[];
  // Workspace for reduction: [0, NUM_WARPS) for the max, [NUM_WARPS, 2 * NUM_WARPS)
  // for block_sum.
  __shared__ float red_smem[2 * NUM_WARPS];

  // Load max logits to shared memory.
  float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
  const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
                                           + head_idx * max_num_partitions;
  float max_logit = -FLT_MAX;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    const float l = max_logits_ptr[i];
    shared_max_logits[i] = l;
    max_logit = fmaxf(max_logit, l);
  }
  __syncthreads();

  // Get the global max logit.
  // Reduce within the warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = max_logit;
  }
  __syncthreads();
  // Reduce across warps.
  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
  }
  // Broadcast the max value to all threads.
  max_logit = VLLM_SHFL_SYNC(max_logit, 0);

  // Load rescaled exp sums to shared memory.
  float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
  const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
                                       + head_idx * max_num_partitions;
  float global_exp_sum = 0.0f;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    float l = shared_max_logits[i];
    float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
    global_exp_sum += rescaled_exp_sum;
    shared_exp_sums[i] = rescaled_exp_sum;
  }
  __syncthreads();
  global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
  const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);

  // Aggregate tmp_out to out.
  const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
                                        + head_idx * max_num_partitions * HEAD_SIZE;
  scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
  for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
    float acc = 0.0f;
    for (int j = 0; j < num_partitions; ++j) {
      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
    }
    from_float(out_ptr[i], acc);
  }
}

} // namespace vllm

// Launches paged_attention_v1_kernel for the given HEAD_SIZE on `stream`.
// First it calls VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize, which is
// defined elsewhere and, by its name, raises the kernel's dynamic shared-memory
// limit to shared_mem_size.
// The macro expects these names to be in scope where it is expanded:
//   * T, CACHE_T, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE;
//   * grid, block, shared_mem_size, stream;
//   * the *_ptr, num_kv_heads, scale, max_num_blocks_per_seq and stride variables.
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                                  \
  VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                                       \
    ((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,   \
      IS_FP8_E5M2_KV_CACHE>), shared_mem_size);                                               \
  vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,             \
  IS_FP8_E5M2_KV_CACHE><<<grid, block, shared_mem_size, stream>>>(                            \
    out_ptr,                                                                                  \
    query_ptr,                                                                                \
    key_cache_ptr,                                                                            \
    value_cache_ptr,                                                                          \
    num_kv_heads,                                                                             \
    scale,                                                                                    \
    block_tables_ptr,                                                                         \
    context_lens_ptr,                                                                         \
    max_num_blocks_per_seq,                                                                   \
    alibi_slopes_ptr,                                                                         \
    q_stride,                                                                                 \
    kv_block_stride,                                                                          \
    kv_head_stride);

// Host-side launcher for the single-pass (v1) paged attention kernel.
//
// Launch layout: one thread block per (head, sequence) pair, i.e.
// grid = (num_heads, num_seqs, 1), block = NUM_THREADS threads. The dynamic
// shared memory is sized to hold either the logits for the whole block-padded
// context or the cross-warp partial outputs, whichever is larger.
//
// Template parameters:
//   T                    - element type of query/out (float or uint16_t at
//                          this file's call sites).
//   CACHE_T              - element type of the key/value caches.
//   BLOCK_SIZE           - tokens per KV-cache block (compile-time).
//   IS_FP8_E5M2_KV_CACHE - forwarded unchanged to the kernel.
//   NUM_THREADS          - threads per block.
//
// NOTE: LAUNCH_PAGED_ATTENTION_V1 refers to the locals below by name
// (out_ptr, query_ptr, grid, block, shared_mem_size, stream, ...), so they
// must keep these exact names.
// TODO(woosuk): Tune NUM_THREADS.
template<
  typename T,
  typename CACHE_T,
  int BLOCK_SIZE,
  bool IS_FP8_E5M2_KV_CACHE,
  int NUM_THREADS = 128>
void paged_attention_v1_launcher(
  torch::Tensor& out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  int num_kv_heads,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  // The head must split evenly across the threads of a thread group.
  // TORCH_CHECK (unlike assert) is still enforced in NDEBUG release builds.
  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  TORCH_CHECK(head_size % thread_group_size == 0,
              "head_size (", head_size,
              ") must be divisible by the thread group size (",
              thread_group_size, ")");

  // An empty batch would yield a zero-sized grid, which the runtime rejects
  // with an "invalid configuration" error that stays pending and is reported
  // later by an unrelated operation. There is nothing to compute, so bail out.
  if (num_seqs == 0 || num_heads == 0) {
    return;
  }

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr = alibi_slopes ?
    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
    : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* context_lens_ptr = context_lens.data_ptr<int>();

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
  int logits_size = padded_max_context_len * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
  // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
  // Keep that in sync with the logic here!
  int shared_mem_size = std::max(logits_size, outputs_size);

  dim3 grid(num_heads, num_seqs, 1);
  dim3 block(NUM_THREADS);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (head_size) {
    // NOTE(woosuk): To reduce the compilation time, we only compile for the
    // head sizes that we use in the model. However, we can easily extend this
    // to support any head size which is a multiple of 16.
    case 64:
      LAUNCH_PAGED_ATTENTION_V1(64);
      break;
    case 80:
      LAUNCH_PAGED_ATTENTION_V1(80);
      break;
    case 96:
      LAUNCH_PAGED_ATTENTION_V1(96);
      break;
    case 112:
      LAUNCH_PAGED_ATTENTION_V1(112);
      break;
    case 128:
      LAUNCH_PAGED_ATTENTION_V1(128);
      break;
    case 256:
      LAUNCH_PAGED_ATTENTION_V1(256);
      break;
    default:
      TORCH_CHECK(false, "Unsupported head size: ", head_size);
      break;
  }
}

688
689
690
691
692
693
694
695
696
697
698
// Calls the paged_attention_v1_launcher instantiation for one
// (T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) combination. The runtime
// arguments are taken by name from the expansion site (out, query, key_cache,
// value_cache, num_kv_heads, scale, block_tables, context_lens,
// max_context_len, alibi_slopes), i.e. the parameters of paged_attention_v1.
// The expansion deliberately has no trailing semicolon: the caller supplies
// it, so `CALL_V1_LAUNCHER(...);` is exactly one statement.
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE)       \
  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
    out,                                                                     \
    query,                                                                   \
    key_cache,                                                               \
    value_cache,                                                             \
    num_kv_heads,                                                            \
    scale,                                                                   \
    block_tables,                                                            \
    context_lens,                                                            \
    max_context_len,                                                         \
    alibi_slopes)

// Maps the runtime `block_size` (must be in scope at the expansion site) onto
// a compile-time BLOCK_SIZE instantiation of CALL_V1_LAUNCHER. Unsupported
// block sizes raise through TORCH_CHECK.
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
// The body is wrapped in do { ... } while (0) so `CALL_V1_LAUNCHER_BLOCK_SIZE(...);`
// behaves as a single statement, e.g. as the body of an unbraced if/else.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
  do {                                                                \
    switch (block_size) {                                             \
      case 8:                                                         \
        CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE);        \
        break;                                                        \
      case 16:                                                        \
        CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE);       \
        break;                                                        \
      case 32:                                                        \
        CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE);       \
        break;                                                        \
      default:                                                        \
        TORCH_CHECK(false, "Unsupported block size: ", block_size);   \
        break;                                                        \
    }                                                                 \
  } while (0)

// Entry point for paged attention v1 (single pass over the whole context).
//
// Selects a paged_attention_v1_launcher instantiation from three runtime
// values: the KV-cache dtype string, the query's scalar type and block_size.
// Any unsupported combination raises via TORCH_CHECK.
//
// Supported in this build: kv_cache_dtype == "auto" with Float or Half
// queries. The upstream BFloat16 query dispatch and the "fp8_e5m2" KV-cache
// dispatch (CACHE_T = uint8_t, IS_FP8_E5M2_KV_CACHE = true) are disabled
// here, so those inputs reach the error paths below.
void paged_attention_v1(
  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
  int num_kv_heads,               // number of KV heads (scalar)
  float scale,
  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
  torch::Tensor& context_lens,    // [num_seqs]
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes,
  const std::string& kv_cache_dtype) {
  // Guard clause: only the "auto" KV-cache dtype is dispatched.
  if (kv_cache_dtype != "auto") {
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
  }

  // With "auto", the cache element type matches the query element type.
  switch (query.scalar_type()) {
    case at::ScalarType::Float:
      CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
      break;
    case at::ScalarType::Half:
      // fp16 tensors are handed to the kernel as raw 16-bit words.
      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
      break;
    default:
      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
      break;
  }
}

// Issues the two v2 kernel launches for one compile-time HEAD_SIZE:
//   1. paged_attention_v2_kernel on `grid` (one block per head, sequence and
//      partition), given exp_sums_ptr / max_logits_ptr / tmp_out_ptr as its
//      per-partition outputs;
//   2. paged_attention_v2_reduce_kernel on `reduce_grid`, which reads those
//      buffers and writes out_ptr.
// Expects paged_attention_v2_launcher's template parameters and locals
// (grid, reduce_grid, block, shared_mem_size, reduce_shared_mem_size, stream,
// the *_ptr pointers, strides, max_num_partitions, ...) to be in scope.
// Wrapped in do { ... } while (0) so both launches always run together as a
// single statement; a bare two-statement macro under an unbraced `if` would
// make only the first launch conditional.
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                                  \
  do {                                                                                        \
    vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,           \
    IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>                                                     \
    <<<grid, block, shared_mem_size, stream>>>(                                               \
      exp_sums_ptr,                                                                           \
      max_logits_ptr,                                                                         \
      tmp_out_ptr,                                                                            \
      query_ptr,                                                                              \
      key_cache_ptr,                                                                          \
      value_cache_ptr,                                                                        \
      num_kv_heads,                                                                           \
      scale,                                                                                  \
      block_tables_ptr,                                                                       \
      context_lens_ptr,                                                                       \
      max_num_blocks_per_seq,                                                                 \
      alibi_slopes_ptr,                                                                       \
      q_stride,                                                                               \
      kv_block_stride,                                                                        \
      kv_head_stride);                                                                        \
    vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE>         \
    <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(                                 \
      out_ptr,                                                                                \
      exp_sums_ptr,                                                                           \
      max_logits_ptr,                                                                         \
      tmp_out_ptr,                                                                            \
      context_lens_ptr,                                                                       \
      max_num_partitions);                                                                    \
  } while (0)

// Host-side launcher for the partitioned (v2) paged attention path.
//
// The context is split into partitions of PARTITION_SIZE tokens. The main
// kernel runs on grid = (num_heads, num_seqs, max_num_partitions) and fills
// the per-partition buffers exp_sums / max_logits
// ([num_seqs, num_heads, max_num_partitions]) and tmp_out
// ([num_seqs, num_heads, max_num_partitions, head_size]). The reduce kernel
// then runs on reduce_grid = (num_heads, num_seqs) and writes `out`. Both
// launches use block = NUM_THREADS threads.
//
// NOTE: LAUNCH_PAGED_ATTENTION_V2 refers to the locals below by name, so they
// must keep these exact names.
template<
  typename T,
  typename CACHE_T,
  int BLOCK_SIZE,
  bool IS_FP8_E5M2_KV_CACHE,
  int NUM_THREADS = 128,
  int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(
  torch::Tensor& out,
  torch::Tensor& exp_sums,
  torch::Tensor& max_logits,
  torch::Tensor& tmp_out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  int num_kv_heads,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  // The head must split evenly across the threads of a thread group.
  // TORCH_CHECK (unlike assert) is still enforced in NDEBUG release builds.
  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  TORCH_CHECK(head_size % thread_group_size == 0,
              "head_size (", head_size,
              ") must be divisible by the thread group size (",
              thread_group_size, ")");

  // An empty batch would yield zero-sized grids, which the runtime rejects
  // with an "invalid configuration" error that stays pending and is reported
  // later by an unrelated operation. There is nothing to compute, so bail out.
  if (num_seqs == 0 || num_heads == 0) {
    return;
  }

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr = alibi_slopes ?
    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
    : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* context_lens_ptr = context_lens.data_ptr<int>();

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
  int logits_size = PARTITION_SIZE * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);

  // For paged attention v2 kernel.
  dim3 grid(num_heads, num_seqs, max_num_partitions);
  int shared_mem_size = std::max(logits_size, outputs_size);
  // For paged attention v2 reduce kernel: two floats per partition.
  dim3 reduce_grid(num_heads, num_seqs);
  int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);

  dim3 block(NUM_THREADS);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (head_size) {
    // NOTE(woosuk): To reduce the compilation time, we only compile for the
    // head sizes that we use in the model. However, we can easily extend this
    // to support any head size which is a multiple of 16.
    case 64:
      LAUNCH_PAGED_ATTENTION_V2(64);
      break;
    case 80:
      LAUNCH_PAGED_ATTENTION_V2(80);
      break;
    case 96:
      LAUNCH_PAGED_ATTENTION_V2(96);
      break;
    case 112:
      LAUNCH_PAGED_ATTENTION_V2(112);
      break;
    case 128:
      LAUNCH_PAGED_ATTENTION_V2(128);
      break;
    case 256:
      LAUNCH_PAGED_ATTENTION_V2(256);
      break;
    default:
      TORCH_CHECK(false, "Unsupported head size: ", head_size);
      break;
  }
}

875
876
877
878
879
880
881
882
883
884
885
886
887
888
// Calls the paged_attention_v2_launcher instantiation for one
// (T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) combination. The runtime
// arguments are taken by name from the expansion site (out, exp_sums,
// max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale,
// block_tables, context_lens, max_context_len, alibi_slopes), i.e. the
// parameters of paged_attention_v2. The expansion deliberately has no
// trailing semicolon: the caller supplies it, so `CALL_V2_LAUNCHER(...);` is
// exactly one statement.
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE)           \
  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>(     \
    out,                                                                         \
    exp_sums,                                                                    \
    max_logits,                                                                  \
    tmp_out,                                                                     \
    query,                                                                       \
    key_cache,                                                                   \
    value_cache,                                                                 \
    num_kv_heads,                                                                \
    scale,                                                                       \
    block_tables,                                                                \
    context_lens,                                                                \
    max_context_len,                                                             \
    alibi_slopes)
Woosuk Kwon's avatar
Woosuk Kwon committed
890

Woosuk Kwon's avatar
Woosuk Kwon committed
891
892
// Maps the runtime `block_size` (must be in scope at the expansion site) onto
// a compile-time BLOCK_SIZE instantiation of CALL_V2_LAUNCHER. Unsupported
// block sizes raise through TORCH_CHECK.
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
// The body is wrapped in do { ... } while (0) so `CALL_V2_LAUNCHER_BLOCK_SIZE(...);`
// behaves as a single statement, e.g. as the body of an unbraced if/else.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE)       \
  do {                                                                      \
    switch (block_size) {                                                   \
      case 8:                                                               \
        CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE);              \
        break;                                                              \
      case 16:                                                              \
        CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE);             \
        break;                                                              \
      case 32:                                                              \
        CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE);             \
        break;                                                              \
      default:                                                              \
        TORCH_CHECK(false, "Unsupported block size: ", block_size);         \
        break;                                                              \
    }                                                                       \
  } while (0)

909
// Entry point for paged attention v2 (partitioned context plus reduction).
//
// Selects a paged_attention_v2_launcher instantiation from three runtime
// values: the KV-cache dtype string, the query's scalar type and block_size.
// Any unsupported combination raises via TORCH_CHECK. exp_sums, max_logits
// and tmp_out are caller-provided scratch buffers for the per-partition
// results.
//
// Supported in this build: kv_cache_dtype == "auto" with Float or Half
// queries. The upstream BFloat16 query dispatch and the "fp8_e5m2" KV-cache
// dispatch (CACHE_T = uint8_t, IS_FP8_E5M2_KV_CACHE = true) are disabled
// here, so those inputs reach the error paths below.
void paged_attention_v2(
  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
  torch::Tensor& exp_sums,        // [num_seqs, num_heads, max_num_partitions]
  torch::Tensor& max_logits,      // [num_seqs, num_heads, max_num_partitions]
  torch::Tensor& tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
  int num_kv_heads,               // number of KV heads (scalar)
  float scale,
  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
  torch::Tensor& context_lens,    // [num_seqs]
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes,
  const std::string& kv_cache_dtype) {
  // Guard clause: only the "auto" KV-cache dtype is dispatched.
  if (kv_cache_dtype != "auto") {
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
  }

  // With "auto", the cache element type matches the query element type.
  switch (query.scalar_type()) {
    case at::ScalarType::Float:
      CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
      break;
    case at::ScalarType::Half:
      // fp16 tensors are handed to the kernel as raw 16-bit words.
      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
      break;
    default:
      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
      break;
  }
}

// Retire the helper macros defined at the top of this file so their short,
// generic names do not leak beyond it.
#undef DIVIDE_ROUND_UP
#undef MIN
#undef MAX
#undef WARP_SIZE