attention_kernels.cuh 28.7 KB
Newer Older
zhuwenwen's avatar
zhuwenwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
/*
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

20
#include <torch/all.h>
Woosuk Kwon's avatar
Woosuk Kwon committed
21
#include <ATen/cuda/CUDAContext.h>
22
#include <c10/cuda/CUDAGuard.h>
23
#include <algorithm>
Woosuk Kwon's avatar
Woosuk Kwon committed
24

Woosuk Kwon's avatar
Woosuk Kwon committed
25
#include "attention_dtypes.h"
Woosuk Kwon's avatar
Woosuk Kwon committed
26
#include "attention_utils.cuh"
27
28
29

#ifdef USE_ROCM
  #include <hip/hip_bf16.h>
30
  #include "../quantization/fp8/amd/quant_utils.cuh"
31
typedef __hip_bfloat16 __nv_bfloat16;
32
33
#else
  #include "../quantization/fp8/nvidia/quant_utils.cuh"
34
35
#endif

xiabo's avatar
xiabo committed
36
37
#include "../quantization/int8_kvcache/quant_utils.cuh"

38
// Number of threads per warp (NVIDIA) / wavefront (ROCm). On NVIDIA this is
// the literal 32; on ROCm it expands to HIP's `warpSize`.
#ifndef USE_ROCM
  #define WARP_SIZE 32
#else
  #define WARP_SIZE warpSize
#endif

// Function-like macros rather than templates: they appear in constant
// expressions that mix int and size_t operands (e.g. the VEC_SIZE and
// V_VEC_SIZE computations below). NOTE: arguments are evaluated more than once,
// so never pass expressions with side effects.
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
// Integer ceiling division: ceil(a / b), intended for a >= 0 and b > 0.
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
namespace vllm {
Woosuk Kwon's avatar
Woosuk Kwon committed
49
50

// Utility function for attention softmax.
//
// Block-wide sum: every thread contributes its partial `sum` and every thread
// receives the total over all NUM_WARPS warps.
//  - `red_smem` must point to at least NUM_WARPS floats of shared memory
//    (slot `w` receives warp `w`'s partial sum).
//  - Assumes blockDim.x == NUM_WARPS * WARP_SIZE and that NUM_WARPS is a power
//    of two not larger than WARP_SIZE -- TODO confirm at call sites.
//  - Contains a __syncthreads(), so every thread of the block must call it.
//  - There is no barrier after the shared-memory read, so the same `red_smem`
//    slots must not be rewritten without an intervening __syncthreads().
template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
  // Decompose the thread index into warp / lane.
  int warp = threadIdx.x / WARP_SIZE;
  int lane = threadIdx.x % WARP_SIZE;

  // Compute the sum per warp. The XOR-shuffle butterfly leaves the warp total
  // in every lane.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    sum += VLLM_SHFL_XOR_SYNC(sum, mask);
  }

  // Warp leaders store the data to shared memory.
  if (lane == 0) {
    red_smem[warp] = sum;
  }

  // Make sure the data is in shared memory.
  __syncthreads();

  // Lanes [0, NUM_WARPS) of every warp load the per-warp partial sums. Lanes
  // >= NUM_WARPS keep their old value; with NUM_WARPS a power of two, the XOR
  // masks below never mix those lanes into lane 0.
  if (lane < NUM_WARPS) {
    sum = red_smem[lane];
  }

  // Parallel reduction inside the warp.
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    sum += VLLM_SHFL_XOR_SYNC(sum, mask);
  }

  // Broadcast lane 0's block total to the other threads of the warp.
  return VLLM_SHFL_SYNC(sum, 0);
}
// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
//
// Device body shared by paged_attention_v1_kernel and paged_attention_v2_kernel.
// Each thread block handles one query head (blockIdx.x) of one sequence
// (blockIdx.y) over one partition of that sequence's KV tokens (blockIdx.z):
//   1. load the query into shared memory,
//   2. compute scale * q.k (+ ALiBi bias) for every token into `logits`,
//   3. softmax over the partition's tokens,
//   4. accumulate softmax(logits) * V in fp32 and reduce across warps,
//   5. write HEAD_SIZE outputs for this (seq, head, partition); when
//      partitioning, also write the partition's max logit and exp-sum.
// PARTITION_SIZE == 0 disables partitioning: one block covers the whole
// sequence and exp_sums / max_logits are never written.
// Assumptions (TODO confirm against the launcher): blockDim.x == NUM_THREADS,
// and the dynamic shared memory holds at least
// max(num_blocks_per_partition * BLOCK_SIZE, (NUM_WARPS / 2) * HEAD_SIZE)
// floats, because it is reused for the logits and for the cross-warp output
// reduction.
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
          int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
          bool IS_BLOCK_SPARSE,
          int PARTITION_SIZE = 0>  // Zero means no partitioning.
__device__ void paged_attention_kernel(
    float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits,  // [num_seqs, num_heads,
                                     // max_num_partitions]
    scalar_t* __restrict__ out,  // [num_seqs, num_heads, max_num_partitions,
                                 // head_size]
    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
                                          // head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
                                          // head_size, block_size]
    const int num_kv_heads,               // number of KV heads (scalar)
    const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ seq_lens,      // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride,
    const float* k_scale, const float* v_scale, const int tp_rank,
    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
  const int seq_idx = blockIdx.y;
  const int partition_idx = blockIdx.z;
  const int max_num_partitions = gridDim.z;
  constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
  const int seq_len = seq_lens[seq_idx];
  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
    // No work to do. Terminate the thread block.
    return;
  }

  const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
  const int num_blocks_per_partition =
      USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;

  // [start_block_idx, end_block_idx) is the range of blocks to process.
  const int start_block_idx =
      USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
  const int end_block_idx =
      MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
  const int num_blocks = end_block_idx - start_block_idx;

  // [start_token_idx, end_token_idx) is the range of tokens to process.
  const int start_token_idx = start_block_idx * BLOCK_SIZE;
  const int end_token_idx =
      MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
  const int num_tokens = end_token_idx - start_token_idx;

  constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
  // Both operands are compile-time constants, so enforce this at build time
  // (a device-side assert() is compiled out under NDEBUG).
  static_assert(NUM_THREADS % THREAD_GROUP_SIZE == 0,
                "THREAD_GROUP_SIZE must divide NUM_THREADS");
  constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE;
  constexpr int NUM_TOKENS_PER_THREAD_GROUP =
      DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int thread_idx = threadIdx.x;
  const int warp_idx = thread_idx / WARP_SIZE;
  const int lane = thread_idx % WARP_SIZE;

  const int head_idx = blockIdx.x;
  const int num_heads = gridDim.x;
  // Consecutive groups of num_queries_per_kv query heads share one KV head.
  const int num_queries_per_kv = num_heads / num_kv_heads;
  const int kv_head_idx = head_idx / num_queries_per_kv;
  const float alibi_slope =
      alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];

  // A vector type to store a part of a key or a query.
  // The vector size is configured in such a way that the threads in a thread
  // group fetch or compute 16 bytes at a time. For example, if the size of a
  // thread group is 4 and the data type is half, then the vector size is 16 /
  // (4 * sizeof(half)) == 2.
  constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
  using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
  using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
  using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;

  constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
  constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;

  const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
  const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;

  // Load the query into shared memory (q_vecs is __shared__).
  // Each thread in a thread group has a different part of the query.
  // For example, if the thread group size is 4, then the first thread in
  // the group has 0, 4, 8, ... th vectors of the query, and the second thread
  // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
  // q is split from a qkv tensor, it may not be contiguous.
  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
  __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll
  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
       i += NUM_THREAD_GROUPS) {
    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
    q_vecs[thread_group_offset][i] =
        *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
  }
  __syncthreads();  // TODO(naed90): possible speedup if this is replaced with a
                    // memory wall right before we use q_vecs

  // Memory planning.
  extern __shared__ char shared_mem[];
  // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
  float* logits = reinterpret_cast<float*>(shared_mem);
  // Workspace for reduction: [0, NUM_WARPS) holds the per-warp qk maxima and
  // [NUM_WARPS, 2 * NUM_WARPS) is handed to block_sum for the exp-sum.
  __shared__ float red_smem[2 * NUM_WARPS];

  // x == THREAD_GROUP_SIZE * VEC_SIZE
  // Each thread group fetches x elements from the key at a time.
  constexpr int x = 16 / sizeof(cache_t);
  float qk_max = -FLT_MAX;

  // Iterate over the key blocks.
  // Each warp fetches a block of keys for each iteration.
  // Each thread group in a warp fetches a key from the block, and computes
  // dot product with the query.
  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;

  // blocksparse specific vars (only read when IS_BLOCK_SPARSE is true).
  int bs_block_offset = 0;
  int q_bs_block_id = 0;
  if constexpr (IS_BLOCK_SPARSE) {
    // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
    // blocksparse_block_size);
    q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
    if (blocksparse_head_sliding_step >= 0)
      // sliding on q heads
      bs_block_offset =
          (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
    else
      // sliding on kv heads
      bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
                            (-blocksparse_head_sliding_step) +
                        1;
  }

  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
       block_idx += NUM_WARPS) {
    // For blocksparse attention: skip computation on blocks that are not
    // attended
    if constexpr (IS_BLOCK_SPARSE) {
      const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
      const bool is_remote =
          ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0);
      const bool is_local =
          (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks);
      if (!is_remote && !is_local) {
        for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
          const int physical_block_offset =
              (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
          const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;

          if (thread_group_offset == 0) {
            // NOTE(linxihui): assign very large number to skipped tokens to
            // avoid contribution to the sumexp softmax normalizer. This will
            // not be used at computing sum(softmax*v) as the blocks will be
            // skipped.
            logits[token_idx - start_token_idx] = -FLT_MAX;
          }
        }
        continue;
      }
    }

    // NOTE(woosuk): The block number is stored in int32. However, we cast it to
    // int64 because int32 can lead to overflow when this variable is multiplied
    // by large numbers (e.g., kv_block_stride).
    const int64_t physical_block_number =
        static_cast<int64_t>(block_table[block_idx]);

    // Load a key to registers.
    // Each thread in a thread group has a different part of the key.
    // For example, if the thread group size is 4, then the first thread in
    // the group has 0, 4, 8, ... th vectors of the key, and the second thread
    // has 1, 5, 9, ... th vectors of the key, and so on.
    for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
      const int physical_block_offset =
          (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
      const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
      K_vec k_vecs[NUM_VECS_PER_THREAD];

#pragma unroll
      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
        const cache_t* k_ptr =
            k_cache + physical_block_number * kv_block_stride +
            kv_head_idx * kv_head_stride + physical_block_offset * x;
        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
        const int offset1 = (vec_idx * VEC_SIZE) / x;
        const int offset2 = (vec_idx * VEC_SIZE) % x;

        if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
          k_vecs[j] = *reinterpret_cast<const K_vec*>(
              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
        } else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kInt8) {
          // Vector conversion from Quant_vec to K_vec (int8 helper, *k_scale).
          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
          k_vecs[j] = int8::scaled_vec_conversion_int8<K_vec, Quant_vec>(
              k_vec_quant, *k_scale);
        } else {
          // Vector conversion from Quant_vec to K_vec.
          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
          k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
              k_vec_quant, *k_scale);
        }
      }

      // Compute dot product.
      // This includes a reduction across the threads in the same thread group.
      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
                             q_vecs[thread_group_offset], k_vecs);
      // Add the ALiBi bias if slopes are given.
      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;

      if (thread_group_offset == 0) {
        // Store the partial reductions to shared memory.
        // NOTE(woosuk): It is required to zero out the masked logits.
        const bool mask = token_idx >= seq_len;
        logits[token_idx - start_token_idx] = mask ? 0.f : qk;
        // Update the max value.
        qk_max = mask ? qk_max : fmaxf(qk_max, qk);
      }
    }
  }

  // Perform reduction across the threads in the same warp to get the
  // max qk value for each "warp" (not across the thread block yet).
  // The 0-th thread of each thread group already has its max qk value.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = qk_max;
  }
  __syncthreads();

  // TODO(woosuk): Refactor this part.
  // Get the max qk value for the sequence.
  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
  }
  // Broadcast the max qk value to all threads.
  qk_max = VLLM_SHFL_SYNC(qk_max, 0);

  // Get the sum of the exp values.
  float exp_sum = 0.f;
  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    float val = __expf(logits[i] - qk_max);
    logits[i] = val;
    exp_sum += val;
  }
  exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);

  // Compute softmax. The 1e-6f guards against division by zero.
  const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    logits[i] *= inv_sum;
  }
  __syncthreads();

  // If partitioning is enabled, store the max logit and exp_sum.
  if (USE_PARTITIONING && thread_idx == 0) {
    float* max_logits_ptr = max_logits +
                            seq_idx * num_heads * max_num_partitions +
                            head_idx * max_num_partitions + partition_idx;
    *max_logits_ptr = qk_max;
    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
                          head_idx * max_num_partitions + partition_idx;
    *exp_sums_ptr = exp_sum;
  }

  // Each thread will fetch 16 bytes from the value cache at a time.
  constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
  using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
  using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
  using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
  using Float_L_vec = typename FloatVec<L_vec>::Type;

  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
  constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
  constexpr int NUM_ROWS_PER_THREAD =
      DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);

  // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
  float accs[NUM_ROWS_PER_THREAD];
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    accs[i] = 0.f;
  }

  scalar_t zero_value;
  zero(zero_value);
  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
       block_idx += NUM_WARPS) {
    // For blocksparse attention: skip computation on blocks that are not
    // attended (same predicate as in the q.k loop above).
    if constexpr (IS_BLOCK_SPARSE) {
      int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
      if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
          !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
        continue;
      }
    }

    // NOTE(woosuk): The block number is stored in int32. However, we cast it to
    // int64 because int32 can lead to overflow when this variable is multiplied
    // by large numbers (e.g., kv_block_stride).
    const int64_t physical_block_number =
        static_cast<int64_t>(block_table[block_idx]);
    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
    L_vec logits_vec;
    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx -
                                                           start_token_idx));

    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
                           kv_head_idx * kv_head_stride;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE) {
        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
        V_vec v_vec;

        if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
          v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
        } else if constexpr (KV_DTYPE == Fp8KVCacheDataType::kInt8) {
          V_quant_vec v_quant_vec =
              *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
          // Vector conversion from V_quant_vec to V_vec.
          v_vec = int8::scaled_vec_conversion_int8<V_vec, V_quant_vec>(
              v_quant_vec, *v_scale);
        } else {
          V_quant_vec v_quant_vec =
              *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
          // Vector conversion from V_quant_vec to V_vec.
          v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
                                                                    *v_scale);
        }
        if (block_idx == num_seq_blocks - 1) {
          // NOTE(woosuk): When v_vec contains the tokens that are out of the
          // context, we should explicitly zero out the values since they may
          // contain NaNs. See
          // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
          scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
          for (int j = 0; j < V_VEC_SIZE; j++) {
            v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
          }
        }
        accs[i] += dot(logits_vec, v_vec);
      }
    }
  }

  // Perform reduction within each warp.
#pragma unroll
  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    float acc = accs[i];
#pragma unroll
    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
      acc += VLLM_SHFL_XOR_SYNC(acc, mask);
    }
    accs[i] = acc;
  }

  // NOTE(woosuk): A barrier is required because the shared memory space for
  // logits is reused for the output.
  __syncthreads();

  // Perform reduction across warps: each step folds warps [mid, i) into
  // warps [0, mid) through out_smem.
  float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
  for (int i = NUM_WARPS; i > 1; i /= 2) {
    int mid = i / 2;
    // Upper warps write to shared memory.
    if (warp_idx >= mid && warp_idx < i) {
      float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll
      for (int r = 0; r < NUM_ROWS_PER_THREAD; r++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + r * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          dst[row_idx] = accs[r];
        }
      }
    }
    __syncthreads();

    // Lower warps update the output.
    if (warp_idx < mid) {
      const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll
      for (int r = 0; r < NUM_ROWS_PER_THREAD; r++) {
        const int row_idx = lane / NUM_V_VECS_PER_ROW + r * NUM_ROWS_PER_ITER;
        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
          accs[r] += src[row_idx];
        }
      }
    }
    __syncthreads();
  }

  // Write the final output. With a single partition (gridDim.z == 1) this is
  // simply out[seq_idx][head_idx][:].
  if (warp_idx == 0) {
    scalar_t* out_ptr =
        out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
        head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
        from_float(*(out_ptr + row_idx), accs[i]);
      }
    }
  }
}
// Grid: (num_heads, num_seqs, 1).
//
// Single-pass ("v1") paged attention. Forwards to paged_attention_kernel with
// the default PARTITION_SIZE of 0, so each thread block processes the whole
// sequence for one (head, sequence) pair and writes its HEAD_SIZE outputs
// directly to `out`. No partition statistics are produced: exp_sums and
// max_logits are passed as nullptr and are never written in that mode.
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
          int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
          bool IS_BLOCK_SPARSE>
__global__ void paged_attention_v1_kernel(
    scalar_t* __restrict__ out,           // [num_seqs, num_heads, head_size]
    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
                                          // head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
                                          // head_size, block_size]
    const int num_kv_heads,               // number of KV heads (scalar)
    const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ seq_lens,      // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride,
    const float* k_scale, const float* v_scale, const int tp_rank,
    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
                         KV_DTYPE, IS_BLOCK_SPARSE>(
      /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
      v_cache, num_kv_heads, scale, block_tables, seq_lens,
      max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
      kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
      blocksparse_vert_stride, blocksparse_block_size,
      blocksparse_head_sliding_step);
}

// Grid: (num_heads, num_seqs, max_num_partitions).
//
// First pass of partitioned ("v2") paged attention. Each thread block handles
// one PARTITION_SIZE-token slice of one (head, sequence) pair. It writes its
// partial output to tmp_out[seq][head][partition][:] and the slice's max logit
// and exp-sum to max_logits / exp_sums. paged_attention_v2_reduce_kernel
// consumes these three buffers.
// Presumably PARTITION_SIZE > 0 and a multiple of BLOCK_SIZE -- TODO confirm
// at the launcher.
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
          int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
          bool IS_BLOCK_SPARSE,
          int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel(
    float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits,       // [num_seqs, num_heads,
                                          // max_num_partitions]
    scalar_t* __restrict__ tmp_out,       // [num_seqs, num_heads,
                                          // max_num_partitions, head_size]
    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
                                          // head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
                                          // head_size, block_size]
    const int num_kv_heads,               // number of KV heads (scalar)
    const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ seq_lens,      // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride,
    const float* k_scale, const float* v_scale, const int tp_rank,
    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
                         KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
      exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
      block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
      kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
      blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
      blocksparse_head_sliding_step);
}

// Grid: (num_heads, num_seqs).
// Block: NUM_THREADS threads. The cross-warp reduction and the final
// accumulation loop are written in terms of NUM_THREADS rather than
// blockDim.x, so the kernel expects to be launched with exactly NUM_THREADS
// threads per block.
// Dynamic shared memory: at least 2 * num_partitions floats (layout below).
//
// Merges the per-partition partial results (tmp_out, exp_sums, max_logits)
// for one (sequence, head) pair into the final attention output. With
// M = max_j max_logits[j] and w_j = exp_sums[j] * exp(max_logits[j] - M),
// it writes
//   out[i] = sum_j tmp_out[j][i] * w_j / (sum_j w_j + 1e-6)
// Subtracting M before exponentiating keeps every expf argument <= 0.
// When the sequence fits in a single partition, tmp_out is copied to out
// unchanged.
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
          int PARTITION_SIZE>
__global__ void paged_attention_v2_reduce_kernel(
    scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size]
    const float* __restrict__ exp_sums,    // [num_seqs, num_heads,
                                           // max_num_partitions]
    const float* __restrict__ max_logits,  // [num_seqs, num_heads,
                                           // max_num_partitions]
    const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
                                           // max_num_partitions, head_size]
    const int* __restrict__ seq_lens,      // [num_seqs]
    const int max_num_partitions) {
  const int num_heads = gridDim.x;
  const int head_idx = blockIdx.x;
  const int seq_idx = blockIdx.y;
  const int seq_len = seq_lens[seq_idx];
  const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);

  // Flattened offsets of this (seq, head) pair. They are computed in 64-bit
  // because seq_idx * num_heads * max_num_partitions * HEAD_SIZE can exceed
  // INT_MAX for large batches with long contexts; with 32-bit arithmetic the
  // product would silently wrap and the kernel would index out of bounds.
  const int64_t seq_head_idx =
      static_cast<int64_t>(seq_idx) * num_heads + head_idx;
  const int64_t partition_offset = seq_head_idx * max_num_partitions;
  scalar_t* out_ptr = out + seq_head_idx * HEAD_SIZE;
  const scalar_t* tmp_out_ptr = tmp_out + partition_offset * HEAD_SIZE;

  if (num_partitions == 1) {
    // No need to reduce. Only copy tmp_out to out.
    for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
      out_ptr[i] = tmp_out_ptr[i];
    }
    // Terminate the thread block. num_partitions depends only on blockIdx,
    // so every thread of the block takes this branch and no barrier below
    // is skipped by a subset of threads.
    return;
  }

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int warp_idx = threadIdx.x / WARP_SIZE;
  const int lane = threadIdx.x % WARP_SIZE;

  // Dynamic shared memory layout (2 * num_partitions floats):
  //   [0, num_partitions)                   max logit of each partition
  //   [num_partitions, 2 * num_partitions)  rescaled exp sum of each partition
  extern __shared__ char shared_mem[];
  float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
  float* shared_exp_sums =
      reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
  // Reduction workspace: slots [0, NUM_WARPS) hold the per-warp maxima,
  // slots [NUM_WARPS, 2 * NUM_WARPS) are handed to block_sum.
  __shared__ float red_smem[2 * NUM_WARPS];

  const float* max_logits_ptr = max_logits + partition_offset;
  const float* exp_sums_ptr = exp_sums + partition_offset;

  // Stage the partition max logits in shared memory while tracking a
  // per-thread running max.
  float max_logit = -FLT_MAX;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    const float l = max_logits_ptr[i];
    shared_max_logits[i] = l;
    max_logit = fmaxf(max_logit, l);
  }
  __syncthreads();

  // Get the global max logit.
  // Reduce within the warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = max_logit;
  }
  __syncthreads();
  // Reduce across warps. Every warp runs the same reduction over red_smem,
  // so lane 0 of every warp ends up with the block-wide max. The XOR
  // butterfly only combines all NUM_WARPS entries when NUM_WARPS is a power
  // of two no larger than WARP_SIZE.
  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
  }
  // Broadcast the max value to all threads.
  max_logit = VLLM_SHFL_SYNC(max_logit, 0);

  // Rescale each partition's exp sum to the global max, stage it in shared
  // memory, and accumulate this thread's share of the global exp sum.
  float global_exp_sum = 0.0f;
  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
    float l = shared_max_logits[i];
    float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
    global_exp_sum += rescaled_exp_sum;
    shared_exp_sums[i] = rescaled_exp_sum;
  }
  __syncthreads();
  global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
  // The 1e-6f term avoids dividing by zero when the sum is 0 (for example
  // seq_len == 0, which yields num_partitions == 0).
  const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);

  // Aggregate tmp_out to out: weighted sum of the partition outputs.
#pragma unroll
  for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
    float acc = 0.0f;
    for (int j = 0; j < num_partitions; ++j) {
      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
             inv_global_exp_sum;
    }
    from_float(out_ptr[i], acc);
  }
}

}  // namespace vllm

// Drop the file-local helper macros so they do not leak into translation
// units that include this header.
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP