"docs/source/features/multimodal_inputs.md" did not exist on "843b222723b659e4b80d71d3ffb4944266af1d74"
cache_kernels.cu 70.8 KB
Newer Older
1
#include <torch/all.h>
Woosuk Kwon's avatar
Woosuk Kwon committed
2
#include <ATen/cuda/CUDAContext.h>
3
#include <c10/cuda/CUDAGuard.h>
4
#include <c10/cuda/CUDAException.h>
Woosuk Kwon's avatar
Woosuk Kwon committed
5

6
#include "cuda_utils.h"
7
#include "cuda_compat.h"
8
#include "dispatch_utils.h"
9
#include "quantization/vectorization_utils.cuh"
10
11

#ifdef USE_ROCM
12
  #include "quantization/fp8/amd/quant_utils.cuh"
13
#else
14
  #include "quantization/fp8/nvidia/quant_utils.cuh"
15
#endif
16

xiabo's avatar
xiabo committed
17
18
#include "quantization/int8_kvcache/quant_utils.cuh"

Woosuk Kwon's avatar
Woosuk Kwon committed
19
#include <algorithm>
Woosuk Kwon's avatar
Woosuk Kwon committed
20
#include <cassert>
21
#include <cfloat>  // FLT_MIN
Woosuk Kwon's avatar
Woosuk Kwon committed
22
#include <map>
23
#include <vector>
zhuwenwen's avatar
zhuwenwen committed
24
#include <ATen/cuda/CUDAContext.h>
Woosuk Kwon's avatar
Woosuk Kwon committed
25

26
27
#ifdef USE_ROCM
  #include <hip/hip_bf16.h>
28
typedef __hip_bfloat16 __nv_bfloat16;
29
30
#endif

31
32
// Copies whole cache blocks from `src` to `dst` with one cudaMemcpyAsync per
// block on the current CUDA stream. Supported directions: GPU->GPU (same
// device), GPU->CPU and CPU->GPU; anything else raises.
//
// block_mapping: CPU tensor; row i holds (src_block_number, dst_block_number)
//   in columns 0 and 1.
// A block spans `src.stride(0)` elements of `src.element_size()` bytes; `dst`
// is assumed to use the same per-block footprint (not checked here).
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                 const torch::Tensor& block_mapping) {
  torch::Device src_device = src.device();
  torch::Device dst_device = dst.device();
  cudaMemcpyKind memcpy_type;
  if (src_device.is_cuda() && dst_device.is_cuda()) {
    TORCH_CHECK(src_device.index() == dst_device.index(),
                "src and dst must be on the same GPU");
    memcpy_type = cudaMemcpyDeviceToDevice;
  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
    memcpy_type = cudaMemcpyDeviceToHost;
  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
    memcpy_type = cudaMemcpyHostToDevice;
  } else {
    TORCH_CHECK(false, "Invalid device combination");
  }

  // NOTE(youkaichao): keep in mind that `block_mapping` should be
  // a cpu tensor, otherwise every element read would require a gpu-cpu
  // synchronization.
  TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");

  const int64_t num_blocks = block_mapping.size(0);
  if (num_blocks == 0) {
    return;
  }
  // Convert once to int64 and read through an accessor instead of building
  // two tensor views and calling item() for every row.
  const torch::Tensor mapping = block_mapping.to(torch::kInt64);
  const auto mapping_acc = mapping.accessor<int64_t, 2>();

  char* src_ptr = static_cast<char*>(src.data_ptr());
  char* dst_ptr = static_cast<char*>(dst.data_ptr());

  // We use the stride instead of numel in case the cache is padded for memory
  // alignment reasons, we assume the blocks data (inclusive of any padding)
  // is contiguous in memory
  const int64_t block_size_in_bytes = src.element_size() * src.stride(0);
  const at::cuda::OptionalCUDAGuard device_guard(
      src_device.is_cuda() ? src_device : dst_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // NOTE(woosuk): This can be slow if the number of blocks is large.
  for (int64_t i = 0; i < num_blocks; ++i) {
    const int64_t src_offset = mapping_acc[i][0] * block_size_in_bytes;
    const int64_t dst_offset = mapping_acc[i][1] * block_size_in_bytes;
    // Report a failed copy here instead of at some later, unrelated CUDA call.
    C10_CUDA_CHECK(cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset,
                                   block_size_in_bytes, memcpy_type, stream));
  }
}
Woosuk Kwon's avatar
Woosuk Kwon committed
74

Woosuk Kwon's avatar
Woosuk Kwon committed
75
namespace vllm {

// Copies one (src, dst) block pair inside one layer's key cache and then inside
// the same layer's value cache.
// Grid: (num_layers, num_pairs). blockIdx.x picks the layer, blockIdx.y picks
// the pair; the threads of the block stride over the `numel_per_block`
// elements of the block. block_mapping is read as flat (src, dst) pairs.
template <typename scalar_t>
__global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
                                   int64_t* value_cache_ptrs,
                                   const int64_t* __restrict__ block_mapping,
                                   const int numel_per_block) {
  const int layer = blockIdx.x;
  const int pair = blockIdx.y;
  const int64_t src_base = block_mapping[2 * pair] * numel_per_block;
  const int64_t dst_base = block_mapping[2 * pair + 1] * numel_per_block;

  // Key cache first ...
  scalar_t* const keys = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer]);
  for (int e = threadIdx.x; e < numel_per_block; e += blockDim.x) {
    keys[dst_base + e] = keys[src_base + e];
  }

  // ... then the value cache, at the same element offsets.
  scalar_t* const vals = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer]);
  for (int e = threadIdx.x; e < numel_per_block; e += blockDim.x) {
    vals[dst_base + e] = vals[src_base + e];
  }
}

// Kernel for MLA, which works on a single joint kv_cache.
// Grid: (num_layers, num_pairs), same indexing as copy_blocks_kernel; a block
// occupies `mem_footprint_per_block` elements of the layer's cache.
template <typename scalar_t>
__global__ void copy_blocks_mla_kernel(
    int64_t* cache_ptrs, const int64_t* __restrict__ block_mapping,
    const int mem_footprint_per_block) {
  scalar_t* const cache = reinterpret_cast<scalar_t*>(cache_ptrs[blockIdx.x]);
  const int pair = blockIdx.y;
  const int64_t src_base = block_mapping[2 * pair] * mem_footprint_per_block;
  const int64_t dst_base =
      block_mapping[2 * pair + 1] * mem_footprint_per_block;
  for (int e = threadIdx.x; e < mem_footprint_per_block; e += blockDim.x) {
    cache[dst_base + e] = cache[src_base + e];
  }
}

}  // namespace vllm
125

126
127
128
129
130
// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
//
// For every row p of block_mapping, copies block block_mapping[p][0] onto block
// block_mapping[p][1] in the key cache and the value cache of every layer.
// A block is key_caches[0][0].numel() elements; all layers are assumed to share
// dtype, device and block size with layer 0 (not checked here).
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
                 std::vector<torch::Tensor> const& value_caches,
                 const torch::Tensor& block_mapping) {
  TORCH_CHECK(key_caches.size() == value_caches.size(),
              "key_caches and value_caches must have the same length");
  const int num_layers = static_cast<int>(key_caches.size());
  if (num_layers == 0) {
    return;
  }
  torch::Device cache_device = key_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda());

  // The kernel reads block_mapping from device memory as flat int64
  // (src, dst) pairs. Make sure that is what it gets. These are no-ops when
  // the caller already passes a contiguous int64 tensor on the cache device.
  torch::Tensor mapping =
      block_mapping.to(cache_device, torch::kInt64).contiguous();

  // block_mapping is a 2D tensor with shape (num_pairs, 2).
  const int num_pairs = static_cast<int>(mapping.size(0));
  const int numel_per_block = static_cast<int>(key_caches[0][0].numel());
  if (num_pairs == 0 || numel_per_block == 0) {
    // Nothing to copy. A launch with a zero grid/block dimension would fail.
    return;
  }

  // Create an array of pointers to the key and value caches.
  // std::vector rather than a variable-length array (not standard C++).
  std::vector<int64_t> key_cache_ptrs(num_layers);
  std::vector<int64_t> value_cache_ptrs(num_layers);
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    key_cache_ptrs[layer_idx] =
        reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
    value_cache_ptrs[layer_idx] =
        reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  }

  // Move the data structures to the GPU.
  // NOTE: This synchronizes the CPU and GPU.
  torch::Tensor key_cache_ptrs_tensor =
      torch::from_blob(key_cache_ptrs.data(), {num_layers}, torch::kInt64)
          .to(cache_device);
  torch::Tensor value_cache_ptrs_tensor =
      torch::from_blob(value_cache_ptrs.data(), {num_layers}, torch::kInt64)
          .to(cache_device);

  // Launch the kernel: one CUDA block per (layer, pair). gridDim.y is limited
  // by CUDA (65535); the launch check below reports an oversized num_pairs.
  dim3 grid(num_layers, num_pairs);
  dim3 block(std::min(1024, numel_per_block));
  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
      key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
        vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
            key_cache_ptrs_tensor.data_ptr<int64_t>(),
            value_cache_ptrs_tensor.data_ptr<int64_t>(),
            mapping.data_ptr<int64_t>(), numel_per_block);
      }));
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
// copy blocks kernel for MLA (assumes a joint KV-cache)
//
// For every row p of block_mapping, copies block block_mapping[p][0] onto block
// block_mapping[p][1] in every layer's joint kv cache. A block's footprint is
// kv_caches[0].stride(0) elements; all layers are assumed to match layer 0.
void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
                     const torch::Tensor& block_mapping) {
  const int num_layers = static_cast<int>(kv_caches.size());
  if (num_layers == 0) {
    return;
  }
  torch::Device cache_device = kv_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda(), "kv_cache must be on CUDA");

  // The kernel reads block_mapping from device memory as flat int64
  // (src, dst) pairs; these conversions are no-ops for compliant callers.
  torch::Tensor mapping =
      block_mapping.to(cache_device, torch::kInt64).contiguous();

  const int num_pairs = static_cast<int>(mapping.size(0));
  // We use the stride instead of numel in case the cache is padded for memory
  // alignment reasons, we assume the blocks data (inclusive of any padding)
  // is contiguous in memory
  const int mem_footprint_per_block = static_cast<int>(kv_caches[0].stride(0));
  if (num_pairs == 0 || mem_footprint_per_block == 0) {
    // Nothing to copy. A launch with a zero grid/block dimension would fail.
    return;
  }

  std::vector<int64_t> cache_ptrs(num_layers);
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    cache_ptrs[layer_idx] =
        reinterpret_cast<int64_t>(kv_caches[layer_idx].data_ptr());
  }
  // Blocking host->device copy, so `cache_ptrs` outlives the transfer.
  torch::Tensor cache_ptrs_tensor =
      torch::from_blob(cache_ptrs.data(), {num_layers}, torch::kInt64)
          .to(cache_device);

  dim3 grid(num_layers, num_pairs);
  dim3 block(std::min(1024, mem_footprint_per_block));
  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
      kv_caches[0].scalar_type(), "copy_blocks_mla_kernel", ([&] {
        vllm::copy_blocks_mla_kernel<scalar_t><<<grid, block, 0, stream>>>(
            cache_ptrs_tensor.data_ptr<int64_t>(),
            mapping.data_ptr<int64_t>(), mem_footprint_per_block);
      }));
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

Woosuk Kwon's avatar
Woosuk Kwon committed
214
namespace vllm {
Woosuk Kwon's avatar
Woosuk Kwon committed
215

216
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_kernel(
    const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
    const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
    cache_t* __restrict__ key_cache,     // [num_blocks, num_heads, head_size/x,
                                         // block_size, x]
    cache_t* __restrict__ value_cache,   // [num_blocks, num_heads, head_size,
                                         // block_size]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int key_stride, const int value_stride, const int num_heads,
    const int head_size, const int block_size, const int x,
    const float* k_scale, const float* v_scale) {
  // Scatters token blockIdx.x of key/value into the paged caches at the slot
  // given by slot_mapping; the block's threads stride over the token's
  // num_heads * head_size elements.

  // Element -> cache conversion, chosen at compile time by kv_dt. The scale
  // pointer is only dereferenced for the quantized variants.
  auto to_cache = [](scalar_t val, const float* scale) -> cache_t {
    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
      return val;
    } else if constexpr (kv_dt == Fp8KVCacheDataType::kInt8) {
      return int8::scaled_vec_conversion_int8<cache_t, scalar_t>(val, *scale);
    } else {
      return fp8::scaled_convert<cache_t, scalar_t, kv_dt>(val, *scale);
    }
  };

  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  if (slot_idx < 0) {
    // Padding token that should be ignored.
    return;
  }
  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  // Elements one head occupies inside one cache block. The key span uses
  // (head_size / x) * x to match the [head_size/x, block_size, x] split.
  const int64_t key_head_span =
      static_cast<int64_t>(head_size / x) * block_size * x;
  const int64_t value_head_span = static_cast<int64_t>(head_size) * block_size;
  const int64_t key_block_base = block_idx * num_heads * key_head_span;
  const int64_t value_block_base = block_idx * num_heads * value_head_span;

  const scalar_t* key_row = key + token_idx * key_stride;
  const scalar_t* value_row = value + token_idx * value_stride;

  const int total = num_heads * head_size;
  for (int e = threadIdx.x; e < total; e += blockDim.x) {
    const int head = e / head_size;
    const int within_head = e % head_size;

    const int64_t key_dst =
        key_block_base + head * key_head_span +
        static_cast<int64_t>(within_head / x) * block_size * x +
        block_offset * x + within_head % x;
    const int64_t value_dst = value_block_base + head * value_head_span +
                              static_cast<int64_t>(within_head) * block_size +
                              block_offset;

    const scalar_t k_val = key_row[e];
    const scalar_t v_val = value_row[e];
    key_cache[key_dst] = to_cache(k_val, k_scale);
    value_cache[value_dst] = to_cache(v_val, v_scale);
  }
}

zhuwenwen's avatar
zhuwenwen committed
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_kernel_cuda(
    const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
    const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
    cache_t* __restrict__ key_cache,     // [num_blocks, num_heads, block_size,
                                         // head_size]  target layout
    cache_t* __restrict__ value_cache,   // [num_blocks, num_heads, head_size,
                                         // block_size]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int key_stride, const int value_stride, const int num_heads,
    const int head_size, const int block_size, int x,
    const float* k_scale, const float* v_scale) {
  // Same scatter as reshape_and_cache_kernel, but the key cache keeps each
  // head as a plain [block_size, head_size] tile (no x split). `x` is not
  // referenced here; it only keeps the parameter list parallel.

  // Element -> cache conversion, chosen at compile time by kv_dt.
  auto to_cache = [](scalar_t val, const float* scale) -> cache_t {
    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
      return val;
    } else if constexpr (kv_dt == Fp8KVCacheDataType::kInt8) {
      return int8::scaled_vec_conversion_int8<cache_t, scalar_t>(val, *scale);
    } else {
      return fp8::scaled_convert<cache_t, scalar_t, kv_dt>(val, *scale);
    }
  };

  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  if (slot_idx < 0) {
    // Padding token that should be ignored.
    return;
  }
  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  // Both caches give every head block_size * head_size elements per block,
  // so they share the block/head base offsets.
  const int64_t head_span = static_cast<int64_t>(head_size) * block_size;
  const int64_t block_base = block_idx * num_heads * head_span;

  const scalar_t* key_row = key + token_idx * key_stride;
  const scalar_t* value_row = value + token_idx * value_stride;

  const int total = num_heads * head_size;
  for (int e = threadIdx.x; e < total; e += blockDim.x) {
    const int head = e / head_size;
    const int within_head = e % head_size;
    const int64_t head_base = block_base + head * head_span;

    // K: [..., block_size, head_size]  V: [..., head_size, block_size]
    const int64_t key_dst = head_base + block_offset * head_size + within_head;
    const int64_t value_dst =
        head_base + static_cast<int64_t>(within_head) * block_size +
        block_offset;

    const scalar_t k_val = key_row[e];
    const scalar_t v_val = value_row[e];
    key_cache[key_dst] = to_cache(k_val, k_scale);
    value_cache[value_dst] = to_cache(v_val, v_scale);
  }
}

337
338
339
340
341
342
343
344
345
346
347
348
349
350
// Per-element functor handed to vectorize_with_alignment. It stores `src` into
// `dst`: a plain static_cast when kv_dt is kAuto, otherwise
// fp8::scaled_convert with the stored `scale`.
template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
struct CopyWithScaleOp {
  float scale;  // only read when kv_dt != kAuto

  __device__ __forceinline__ void operator()(OutT& dst, const InT src) const {
    if constexpr (kv_dt != Fp8KVCacheDataType::kAuto) {
      dst = fp8::scaled_convert<OutT, InT, kv_dt>(src, scale);
    } else {
      dst = static_cast<OutT>(src);
    }
  }
};

351
// Writes token blockIdx.x of key/value into a paged cache whose block / slot /
// head strides are passed explicitly. This supports both the NHD layout
// ([num_blocks, block_size, num_heads, head_size], detected by
// head_stride == head_size) and the HND layout
// ([num_blocks, num_heads, block_size, head_size]). Elements are copied, and
// for kv_dt != kAuto converted with *k_scale / *v_scale, through
// vectorize_with_alignment.
// NOTE(review): CopyWithScaleOp has no kInt8 branch, unlike
// reshape_and_cache_kernel. Confirm this kernel is never instantiated with
// kv_dt == Fp8KVCacheDataType::kInt8.
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_flash_kernel(
    const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
    const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
    cache_t* __restrict__ key_cache,     // NHD or HND, shape see comments below
    cache_t* __restrict__ value_cache,   // same above
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int64_t block_stride, const int64_t page_stride,
    const int64_t head_stride, const int64_t key_stride,
    const int64_t value_stride, const int num_heads, const int head_size,
    const int block_size, const float* k_scale, const float* v_scale) {
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  // NOTE: slot_idx can be -1 if the token is padded
  if (slot_idx < 0) {
    return;
  }
  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;
  const int n_elems = num_heads * head_size;

  // pointers to the beginning of the source row for this token.
  const scalar_t* __restrict__ key_src = key + token_idx * key_stride;
  const scalar_t* __restrict__ value_src = value + token_idx * value_stride;

  // find the start position inside the kv-cache for this token.
  cache_t* __restrict__ key_dst =
      key_cache + block_idx * block_stride + block_offset * page_stride;
  cache_t* __restrict__ value_dst =
      value_cache + block_idx * block_stride + block_offset * page_stride;

  // this is true for the NHD layout where `head_stride == head_size`
  const bool is_contiguous_heads = (head_stride == head_size);

  float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale;
  float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale;
  constexpr int VEC_SIZE = (sizeof(scalar_t) == 2) ? 8 : 4;
  CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
  CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
  if (is_contiguous_heads) {
    // NHD layout
    // kv cache: [num_blocks, block_size, num_heads, head_size]
    vectorize_with_alignment<VEC_SIZE>(key_src, key_dst, n_elems, threadIdx.x,
                                       blockDim.x, k_op);

    vectorize_with_alignment<VEC_SIZE>(value_src, value_dst, n_elems,
                                       threadIdx.x, blockDim.x, v_op);

  } else {
    // HND layout: heads are strided, but each head_size segment is contiguous
    // kv cache: [num_blocks, num_heads, block_size, head_size]

    // Copies one head's key and value segment using `nthreads` cooperating
    // threads, of which the caller is thread `tid`.
    auto copy_head = [&](const int head, const int tid, const int nthreads) {
      const scalar_t* __restrict__ k_src_h = key_src + head * head_size;
      const scalar_t* __restrict__ v_src_h = value_src + head * head_size;
      cache_t* __restrict__ k_dst_h =
          key_dst + static_cast<int64_t>(head) * head_stride;
      cache_t* __restrict__ v_dst_h =
          value_dst + static_cast<int64_t>(head) * head_stride;
      vectorize_with_alignment<VEC_SIZE>(k_src_h, k_dst_h, head_size, tid,
                                         nthreads, k_op);
      vectorize_with_alignment<VEC_SIZE>(v_src_h, v_dst_h, head_size, tid,
                                         nthreads, v_op);
    };

    const int warps_per_block = blockDim.x >> 5;  // full warps only
    if (warps_per_block == 0) {
      // Fewer than 32 threads: the warp-per-head loop below would never
      // advance (`head += 0`) and hang, so the whole block walks the heads.
      for (int head = 0; head < num_heads; ++head) {
        copy_head(head, threadIdx.x, blockDim.x);
      }
    } else {
      const int lane = threadIdx.x & 31;     // 0..31 within warp
      const int warp_id = threadIdx.x >> 5;  // warp index within block
      // A trailing partial warp (blockDim.x not a multiple of 32) would only
      // redo heads already owned by warp 0, with some lanes missing; it sits
      // out.
      if (warp_id < warps_per_block) {
        for (int head = warp_id; head < num_heads; head += warps_per_block) {
          // within each head, the 32 threads of the warp do the vector copy
          copy_head(head, lane, 32);
        }
      }
    }
  }
}
425

426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
// Scatters per-layer key/value rows into each layer's paged caches.
// Grid: (num_layers, num_tokens). blockIdx.x selects the layer, blockIdx.y the
// token; threads stride over the token's num_heads * head_size elements.
// Quantized variants (kv_dt != kAuto) convert with a fixed scale of 1.
// NOTE(review): gridDim.y is limited to 65535, so the launcher must keep
// num_tokens within that bound. Confirm at the call site.
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void write_cache_multi_layers_kernel(
    scalar_t* __restrict__ keys,    // [num_layers, num_tokens, num_heads,
                                    // head_size]
    scalar_t* __restrict__ values,  // [num_layers, num_tokens, num_heads,
                                    // head_size]
    int64_t* key_cache_ptrs,    // per-layer key cache addresses; each cache is
                                // [num_blocks, num_heads, head_size/x,
                                // block_size, x]
    int64_t* value_cache_ptrs,  // per-layer value cache addresses; each cache
                                // is [num_blocks, num_heads, head_size,
                                // block_size]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int key_stride, const int value_stride,
    const int num_heads, const int head_size, const int block_size,
    const int x, const int num_tokens) {
  const int layer_idx = blockIdx.x;
  const int64_t token_idx = blockIdx.y;

  const int64_t slot_idx = slot_mapping[token_idx];
  if (slot_idx < 0) {
    // Padding token that should be ignored.
    return;
  }

  cache_t* key_cache = reinterpret_cast<cache_t*>(key_cache_ptrs[layer_idx]);
  cache_t* value_cache =
      reinterpret_cast<cache_t*>(value_cache_ptrs[layer_idx]);

  // This layer's slab of keys/values. Computed in 64 bits because
  // layer_idx * num_tokens * stride can exceed INT_MAX for deep models or
  // large batches.
  const scalar_t* key =
      keys + static_cast<int64_t>(layer_idx) * num_tokens * key_stride;
  const scalar_t* value =
      values + static_cast<int64_t>(layer_idx) * num_tokens * value_stride;

  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;
    const int x_offset = head_offset % x;

    const int64_t tgt_key_idx =
        block_idx * num_heads * (head_size / x) * block_size * x +
        head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
        block_offset * x + x_offset;
    const int64_t tgt_value_idx =
        block_idx * num_heads * head_size * block_size +
        head_idx * head_size * block_size + head_offset * block_size +
        block_offset;

    const int64_t src_key_idx = token_idx * key_stride + i;
    const int64_t src_value_idx = token_idx * value_stride + i;

    scalar_t tgt_key = key[src_key_idx];
    scalar_t tgt_value = value[src_value_idx];

    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
      key_cache[tgt_key_idx] = tgt_key;
      value_cache[tgt_value_idx] = tgt_value;
    } else {
      key_cache[tgt_key_idx] =
          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, 1.0f);
      value_cache[tgt_value_idx] =
          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, 1.0f);
    }
  }
}

// Gathers per-layer key/value rows back out of each layer's paged caches (the
// inverse of write_cache_multi_layers_kernel).
// Grid: (num_layers, num_tokens). blockIdx.x selects the layer, blockIdx.y the
// token; threads stride over the token's num_heads * head_size elements.
// Quantized variants (kv_dt != kAuto) convert with a fixed scale of 1.
// NOTE(review): gridDim.y is limited to 65535, so the launcher must keep
// num_tokens within that bound. Confirm at the call site.
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void read_cache_kernel(
    scalar_t* __restrict__ keys,    // [num_layers, num_tokens, num_heads,
                                    // head_size]
    scalar_t* __restrict__ values,  // [num_layers, num_tokens, num_heads,
                                    // head_size]
    int64_t* key_cache_ptrs,    // per-layer key cache addresses; each cache is
                                // [num_blocks, num_heads, head_size/x,
                                // block_size, x]
    int64_t* value_cache_ptrs,  // per-layer value cache addresses; each cache
                                // is [num_blocks, num_heads, head_size,
                                // block_size]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int key_stride, const int value_stride,
    const int num_heads, const int head_size, const int block_size,
    const int x, const int num_tokens) {
  const int layer_idx = blockIdx.x;
  const int64_t token_idx = blockIdx.y;

  const int64_t slot_idx = slot_mapping[token_idx];
  if (slot_idx < 0) {
    // Padding token that should be ignored.
    return;
  }

  const cache_t* key_cache =
      reinterpret_cast<const cache_t*>(key_cache_ptrs[layer_idx]);
  const cache_t* value_cache =
      reinterpret_cast<const cache_t*>(value_cache_ptrs[layer_idx]);

  // This layer's slab of output keys/values. Computed in 64 bits because
  // layer_idx * num_tokens * stride can exceed INT_MAX for deep models or
  // large batches.
  scalar_t* key =
      keys + static_cast<int64_t>(layer_idx) * num_tokens * key_stride;
  scalar_t* value =
      values + static_cast<int64_t>(layer_idx) * num_tokens * value_stride;

  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;
    const int x_offset = head_offset % x;

    const int64_t src_key_idx =
        block_idx * num_heads * (head_size / x) * block_size * x +
        head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
        block_offset * x + x_offset;
    const int64_t src_value_idx =
        block_idx * num_heads * head_size * block_size +
        head_idx * head_size * block_size + head_offset * block_size +
        block_offset;

    const int64_t tgt_key_idx = token_idx * key_stride + i;
    const int64_t tgt_value_idx = token_idx * value_stride + i;
    cache_t tgt_key = key_cache[src_key_idx];
    cache_t tgt_value = value_cache[src_value_idx];

    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
      key[tgt_key_idx] = tgt_key;
      value[tgt_value_idx] = tgt_value;
    } else {
      key[tgt_key_idx] =
          fp8::scaled_convert<scalar_t, cache_t, kv_dt>(tgt_key, 1.0f);
      value[tgt_value_idx] =
          fp8::scaled_convert<scalar_t, cache_t, kv_dt>(tgt_value, 1.0f);
    }
  }
}
zhuwenwen's avatar
zhuwenwen committed
552
553


554
555
556
557
558
559
560
561
// Writes token blockIdx.x into a joint MLA cache entry: kv_c fills the first
// kv_lora_rank elements of the entry and k_pe the following pe_dim elements.
// For kv_dt != kAuto every element is converted with *scale.
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void concat_and_cache_mla_kernel(
    const scalar_t* __restrict__ kv_c,  // [num_tokens, kv_lora_rank]
    const scalar_t* __restrict__ k_pe,  // [num_tokens, pe_dim]
    cache_t* __restrict__ kv_cache,  // [num_blocks, block_size, (kv_lora_rank
                                     // + pe_dim)]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int block_stride,  // elements between consecutive cache blocks
    const int entry_stride,  // elements between slots inside a block
    const int kv_c_stride,   // elements between kv_c rows
    const int k_pe_stride,   // elements between k_pe rows
    const int kv_lora_rank,  // kv_c elements per token
    const int pe_dim,        // k_pe elements per token
    const int block_size,    // slots per cache block
    const float* scale       // conversion scale (read only if kv_dt != kAuto)
) {
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  // NOTE: slot_idx can be -1 if the token is padded
  if (slot_idx < 0) {
    return;
  }

  // First cache element of this token's entry.
  const int64_t entry_base = (slot_idx / block_size) * block_stride +
                             (slot_idx % block_size) * entry_stride;

  // Copies `len` elements of `row` into the entry starting at `entry_offset`;
  // the block's threads stride over the elements.
  auto write_segment = [&](const scalar_t* __restrict__ row, const int len,
                           const int entry_offset) {
    cache_t* __restrict__ out = kv_cache + entry_base + entry_offset;
    for (int j = threadIdx.x; j < len; j += blockDim.x) {
      if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
        out[j] = row[j];
      } else {
        out[j] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(row[j], *scale);
      }
    }
  };

  write_segment(kv_c + token_idx * kv_c_stride, kv_lora_rank, 0);
  write_segment(k_pe + token_idx * k_pe_stride, pe_dim, kv_lora_rank);
}

598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
// Writes token blockIdx.x into an MLA cache entry with a mixed layout. The
// kv_c ("NoPE") part is quantized to FP8 E4M3 in tiles of 128 elements, each
// tile with its own float scale. k_pe ("RoPE") is stored unquantized as
// scalar_t.
// Byte layout of one entry, assuming sizeof(cache_t) == 1 and
// sizeof(scalar_t) == 2 (implied by the index math below; confirm):
//   [0, kv_lora_rank)                  FP8 NoPE values
//   [kv_lora_rank, kv_lora_rank + 16)  4 float tile scales
//   [kv_lora_rank + 16, ...)           pe_dim 16-bit RoPE values
// NOTE(review): the tiling is hard-coded (4 tiles x 128 elements, 16 + 4
// floats of shared memory). It presumably requires kv_lora_rank == 512 and
// blockDim.x == kv_lora_rank + pe_dim (one thread per element, 576 per the
// comment below). Confirm at the launch site.
// `kv_dt` and `scale` are not used by this kernel.
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void concat_and_cache_ds_mla_kernel(
    const scalar_t* __restrict__ kv_c,  // [num_tokens, kv_lora_rank]
    const scalar_t* __restrict__ k_pe,  // [num_tokens, pe_dim]
    cache_t* __restrict__ kv_cache,  // [num_blocks, block_size, entry_stride];
                                     // entry layout described above
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int block_stride,                    //
    const int entry_stride,                    //
    const int kv_c_stride,                     //
    const int k_pe_stride,                     //
    const int kv_lora_rank,                    //
    const int pe_dim,                          //
    const int block_size,                      //
    const float* scale                         //
) {
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  // NOTE: slot_idx can be -1 if the token is padded
  if (slot_idx < 0) {
    return;
  }
  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;
  const int64_t dst_idx_start =
      block_idx * block_stride + block_offset * entry_stride;

  // Shared scratch: 16 per-warp partial maxima followed by 4 tile scales.
  // 16-byte alignment is required by the float4 load of shard_abs_max below;
  // a plain float array only guarantees 4 bytes.
  __shared__ __align__(16) float smem[20];
  float* shard_abs_max = smem;
  float* tile_scales = smem + 16;

  // For the NoPE part, each tile of 128 elements is handled by 4 warps
  // (128 threads). There are 4 total tiles, so 16 warps (512 threads).
  // The first thread of the first warp in each tile writes the scale
  // value for the tile. The RoPE part (last 64 elements) is handled
  // by another 2 warps (64 threads).
  // So in total, we use 18 warps (576 threads) per block.

  // Cast kv_cache to 16_bit for RoPE values
  scalar_t* kv_cache_16bit =
      reinterpret_cast<scalar_t*>(&kv_cache[dst_idx_start]);

  // The last 64 threads handle the RoPE part
  // NOTE(review): these threads return before the __syncthreads() calls
  // below. The CUDA programming guide only guarantees __syncthreads() when
  // every thread of the block reaches it; this relies on exited threads not
  // blocking the barrier. Confirm on all targeted GPUs.
  if (threadIdx.x >= kv_lora_rank) {
    const int8_t pe_idx = threadIdx.x - kv_lora_rank;
    const int64_t src_idx = token_idx * k_pe_stride + pe_idx;
    // RoPE values start after the packed 8-bit NoPE values and the
    // 32-bit scales
    const int64_t dst_idx = kv_lora_rank / 2 + 8 + pe_idx;
    kv_cache_16bit[dst_idx] = k_pe[src_idx];
    return;
  }

  // Determine the scale for each chunk of NoPE
  const int16_t tile_idx = threadIdx.x >> 7;
  const int16_t warp_idx = (threadIdx.x & 127) >> 5;
  const int16_t lane_idx = threadIdx.x & 31;

  // Load the NoPE element for this thread into registers
  const int64_t src_idx = token_idx * kv_c_stride + threadIdx.x;
  const scalar_t src_val = kv_c[src_idx];

  // Warp-level reduction to find the max absolute value in the warp;
  // afterwards lane 0 holds the max over its 32-lane group.
  float max_abs = fabsf(src_val);
#pragma unroll
  for (int offset = 16; offset > 0; offset /= 2) {
#ifdef USE_ROCM
    max_abs = fmaxf(max_abs, __shfl_down_sync(UINT64_MAX, max_abs, offset));
#else
    max_abs = fmaxf(max_abs, __shfl_down_sync(0xFFFFFFFF, max_abs, offset));
#endif
  }

  // The first lane of each warp in each tile writes the max_abs of this part
  // of the tile to shared memory
  if (lane_idx == 0) {
    shard_abs_max[tile_idx * 4 + warp_idx] = max_abs;
  }
  __syncthreads();

  // The first lane of the first warp in each tile computes the scale for the
  // tile and writes it to shared memory and to kv_cache.
  // 448 is the largest finite FP8 E4M3 magnitude, so each tile's max maps to
  // the top of the FP8 range.
  if (warp_idx == 0 && lane_idx == 0) {
    float4 shard_abs_max_vec =
        reinterpret_cast<float4*>(shard_abs_max)[tile_idx];
    float tile_scale = fmaxf(fmaxf(shard_abs_max_vec.x, shard_abs_max_vec.y),
                             fmaxf(shard_abs_max_vec.z, shard_abs_max_vec.w)) /
                       448.f;

    // Avoid division by zero in `scaled_convert`
    tile_scales[tile_idx] = fmaxf(tile_scale, FLT_MIN);
    float* kv_cache_32bit = reinterpret_cast<float*>(&kv_cache[dst_idx_start]);
    const uint64_t dst_idx = kv_lora_rank / 4 + tile_idx;
    kv_cache_32bit[dst_idx] = tile_scales[tile_idx];
  }

  __syncthreads();

  // Now all threads in the block scale and write their element
  const float scale_val = tile_scales[tile_idx];
  const int64_t dst_idx = dst_idx_start + threadIdx.x;
  kv_cache[dst_idx] =
      fp8::scaled_convert<uint8_t, scalar_t, Fp8KVCacheDataType::kFp8E4M3>(
          src_val, scale_val);
}

// Quantizes one token's indexer key with one float scale per group of
// 32 lanes * VEC_SIZE elements, and writes the quantized values plus the
// scale into the paged kv_cache.
//
// Launch layout (read from the indexing below; the launcher is outside this
// chunk):
//   * blockIdx.x                             -> token index
//   * (blockIdx.y, threadIdx.y, threadIdx.x) -> flattened position along
//     head_dim, each thread owning VEC_SIZE consecutive elements.
// The amax is reduced over 32-lane groups via __shfl_xor_sync. The scale
// index arithmetic presumably assumes quant_block_size == 32 * VEC_SIZE.
// TODO confirm against the launcher.
//
// Within a cache block, the quantized values occupy the first
// cache_block_size * head_dim entries and the float scales follow them.
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void indexer_k_quant_and_cache_kernel(
    const scalar_t* __restrict__ k,  // [num_tokens, head_dim]
    cache_t* __restrict__ kv_cache,  // [num_blocks, block_size, cache_stride]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int head_dim,                        // dimension of each head
    const int quant_block_size,                // quantization block size
    const int cache_block_size,                // cache block size
    const int cache_stride,  // stride for each token in kv_cache
    const bool use_ue8m0     // use ue8m0 scale format
) {
  constexpr int VEC_SIZE = 4;
  // VEC_SIZE consecutive elements moved with one aligned vector access.
  // The width follows sizeof(scalar_t), so a 32-bit scalar_t loads all
  // VEC_SIZE values; a fixed float2 load only covers VEC_SIZE 16-bit values.
  struct alignas(sizeof(scalar_t) * VEC_SIZE) KVec {
    scalar_t v[VEC_SIZE];
  };

  const int64_t token_idx = blockIdx.x;
  const int64_t head_dim_idx =
      (static_cast<int64_t>(blockIdx.y) * blockDim.y * blockDim.x +
       threadIdx.y * blockDim.x + threadIdx.x) *
      VEC_SIZE;
  const int64_t slot_idx = slot_mapping[token_idx];

  // NOTE: slot_idx can be -1 if the token is padded. slot_idx depends only
  // on blockIdx.x, so every thread of the CUDA block leaves together here.
  if (slot_idx < 0) {
    return;
  }
  const int64_t block_idx = slot_idx / cache_block_size;
  const int64_t block_offset = slot_idx % cache_block_size;

  // Lanes past the end of head_dim must still take part in the full-mask
  // shuffles and __syncwarp() below. Returning before them is undefined
  // for a partially filled warp. These lanes contribute amax = 0 and skip
  // all writes.
  const bool in_range = head_dim_idx < head_dim;

  KVec k_vec;
  float amax = 0.0f;
  if (in_range) {
    k_vec = reinterpret_cast<const KVec*>(
        k)[(token_idx * head_dim + head_dim_idx) / VEC_SIZE];
#pragma unroll
    for (int i = 0; i < VEC_SIZE; i++) {
      amax = fmaxf(amax, fabsf(float(k_vec.v[i])));
    }
  }
#ifndef USE_ROCM
  __syncwarp();
#endif

  // XOR-butterfly reduction: afterwards every lane of a 32-lane group holds
  // the group's maximum.
  for (int mask = 16; mask > 0; mask /= 2) {
#ifdef USE_ROCM
    amax = fmaxf(amax, __shfl_xor_sync(uint64_t(-1), amax, mask));
#else
    amax = fmaxf(amax, __shfl_xor_sync(unsigned(-1), amax, mask));
#endif
  }
#ifndef USE_ROCM
  __syncwarp();
#endif

  // The 1e-4f floor keeps the scale away from zero for all-zero groups.
  // 448 presumably is the FP8 E4M3 maximum magnitude (kv_dt assumed E4M3).
  float scale = fmaxf(amax, 1e-4f) / 448.0f;
  if (use_ue8m0) {
    // Round the scale up to a power of two.
    scale = exp2f(ceilf(log2f(scale)));
  }

  if (!in_range) {
    return;
  }

  const int64_t dst_offset = block_idx * cache_block_size * cache_stride +
                             block_offset * head_dim + head_dim_idx;
#pragma unroll
  for (int i = 0; i < VEC_SIZE; i++) {
    kv_cache[dst_offset + i] =
        fp8::scaled_convert<cache_t, scalar_t, kv_dt>(k_vec.v[i], scale);
  }
  // One lane per row of threads (threadIdx.x == 0) stores the group scale.
  if (threadIdx.x == 0) {
    const int64_t dst_scale_idx =
        block_idx * cache_block_size * cache_stride +
        cache_block_size * head_dim +
        (block_offset * head_dim + head_dim_idx) * 4 / quant_block_size;
    reinterpret_cast<float*>(kv_cache)[dst_scale_idx / 4] = scale;
  }
}

772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
// Copies one token's indexer key into the paged kv_cache without
// quantization, converting each element from scalar_t to cache_t.
//
// Indexing mirrors indexer_k_quant_and_cache_kernel:
//   * blockIdx.x selects the token.
//   * The flattened (blockIdx.y, threadIdx.y, threadIdx.x) index selects
//     VEC_SIZE consecutive elements along head_dim.
// There is no warp-level communication here, so out-of-range threads may
// exit early.
template <typename scalar_t, typename cache_t>
__global__ void indexer_k_cache_kernel(
    const scalar_t* __restrict__ k,  // [num_tokens, head_dim]
    cache_t* __restrict__ kv_cache,  // [num_blocks, block_size, cache_stride]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int head_dim,                        // dimension of each head
    const int cache_block_size,                // cache block size
    const int cache_stride                     // stride for each token in kv_cache
) {
  constexpr int VEC_SIZE = 4;
  // VEC_SIZE consecutive elements moved with one aligned vector access.
  // The width follows sizeof(scalar_t), so a 32-bit scalar_t loads all
  // VEC_SIZE values; a fixed float2 load only covers VEC_SIZE 16-bit values.
  struct alignas(sizeof(scalar_t) * VEC_SIZE) KVec {
    scalar_t v[VEC_SIZE];
  };

  const int64_t token_idx = blockIdx.x;
  const int64_t head_dim_idx =
      (static_cast<int64_t>(blockIdx.y) * blockDim.y * blockDim.x +
       threadIdx.y * blockDim.x + threadIdx.x) *
      VEC_SIZE;
  const int64_t slot_idx = slot_mapping[token_idx];

  // NOTE: slot_idx can be -1 if the token is padded
  if (slot_idx < 0 || (head_dim_idx >= head_dim)) {
    return;
  }
  const int64_t block_idx = slot_idx / cache_block_size;
  const int64_t block_offset = slot_idx % cache_block_size;

  const KVec k_vec = reinterpret_cast<const KVec*>(
      k)[(token_idx * head_dim + head_dim_idx) / VEC_SIZE];

  const int64_t dst_offset = block_idx * cache_block_size * cache_stride +
                             block_offset * head_dim + head_dim_idx;
#pragma unroll
  for (int i = 0; i < VEC_SIZE; i++) {
    // Go through float, then use the dedicated conversion intrinsic for the
    // 16-bit cache types.
    const float val = static_cast<float>(k_vec.v[i]);

    if constexpr (std::is_same<cache_t, at::Half>::value ||
                  std::is_same<cache_t, __half>::value) {
      kv_cache[dst_offset + i] = __float2half(val);
    } else if constexpr (std::is_same<cache_t, at::BFloat16>::value ||
                         std::is_same<cache_t, __nv_bfloat16>::value) {
      kv_cache[dst_offset + i] = __float2bfloat16(val);
    } else if constexpr (std::is_same<cache_t, float>::value) {
      kv_cache[dst_offset + i] = val;
    } else {
      kv_cache[dst_offset + i] = static_cast<cache_t>(val);
    }
  }
}
817
}  // namespace vllm
Woosuk Kwon's avatar
Woosuk Kwon committed
818

819
820
// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
821
// KV_DTYPE is the real data type of kv-cache.
822
823
824
825
826
827
828
829
// Launches vllm::reshape_and_cache_kernel for one (KV_T, CACHE_T, KV_DTYPE)
// combination. The expansion site must provide grid, block, stream, key,
// value, key_cache, value_cache, slot_mapping, key_stride, value_stride,
// num_heads, head_size, block_size, x, k_scale and v_scale.
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE)               \
  vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE>             \
      <<<grid, block, 0, stream>>>(                                   \
          reinterpret_cast<KV_T*>(key.data_ptr()),                    \
          reinterpret_cast<KV_T*>(value.data_ptr()),                  \
          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),           \
          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),         \
          slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
          num_heads, head_size, block_size, x,                        \
          reinterpret_cast<const float*>(k_scale.data_ptr()),         \
          reinterpret_cast<const float*>(v_scale.data_ptr()));
833

Woosuk Kwon's avatar
Woosuk Kwon committed
834
// Scatters key/value into paged caches that use the x-split key layout
// [num_blocks, num_heads, head_size/x, block_size, x].
// Launches one CUDA block per slot_mapping entry, each with
// min(num_heads * head_size, 512) threads, on the current stream of key's
// device.
void reshape_and_cache(
    torch::Tensor& key,    // [num_tokens, num_heads, head_size]
    torch::Tensor& value,  // [num_tokens, num_heads, head_size]
    torch::Tensor&
        key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor&
        value_cache,  // [num_blocks, num_heads, head_size, block_size]
    torch::Tensor& slot_mapping,  // [num_tokens]
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale) {
  int num_tokens = slot_mapping.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  // The CUDA runtime rejects a zero-sized grid as an invalid launch
  // configuration. With no tokens there is nothing to write anyway.
  if (num_tokens == 0) {
    return;
  }

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
                             CALL_RESHAPE_AND_CACHE);
}

zhuwenwen's avatar
zhuwenwen committed
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
// Launches vllm::reshape_and_cache_kernel_cuda for one
// (KV_T, CACHE_T, KV_DTYPE) combination.
// The expansion site supplies grid, block, stream, the four tensors,
// slot_mapping, the strides and sizes, and the scale tensors.
// The constant 1 sits in the argument position where
// CALL_RESHAPE_AND_CACHE passes x.
#define CALL_RESHAPE_AND_CACHE_CUDA(KV_T, CACHE_T, KV_DTYPE)          \
  vllm::reshape_and_cache_kernel_cuda<KV_T, CACHE_T, KV_DTYPE>        \
      <<<grid, block, 0, stream>>>(                                   \
          static_cast<KV_T*>(key.data_ptr()),                         \
          static_cast<KV_T*>(value.data_ptr()),                       \
          static_cast<CACHE_T*>(key_cache.data_ptr()),                \
          static_cast<CACHE_T*>(value_cache.data_ptr()),              \
          slot_mapping.data_ptr<int64_t>(),                           \
          key_stride, value_stride,                                   \
          num_heads, head_size, block_size, 1,                        \
          static_cast<const float*>(k_scale.data_ptr()),              \
          static_cast<const float*>(v_scale.data_ptr()));

// Scatters key/value into paged caches whose key layout is
// [num_blocks, num_heads, block_size, head_size] (no x-split).
// Validates the tensor ranks and the key/value/cache shape agreement, then
// launches one CUDA block per slot_mapping entry with
// min(num_heads * head_size, 512) threads.
void reshape_and_cache_cuda(
    torch::Tensor& key,    // [num_tokens, num_heads, head_size]
    torch::Tensor& value,  // [num_tokens, num_heads, head_size]
    torch::Tensor&
        key_cache,  // [num_blocks, num_heads, block_size, head_size]
    torch::Tensor&
        value_cache,  // [num_blocks, num_heads, head_size, block_size]
    torch::Tensor& slot_mapping,  // [num_tokens]
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale) {
  TORCH_CHECK(key.dim() == 3 && value.dim() == 3,
              "key/value must be [num_tokens, num_heads, head_size]");
  TORCH_CHECK(key_cache.dim() == 4 && value_cache.dim() == 4,
              "cache tensor shape mismatch");
  TORCH_CHECK(key_cache.size(0) == value_cache.size(0) &&
                  key_cache.size(1) == value_cache.size(1) &&
                  key_cache.size(2) == value_cache.size(3) &&
                  key_cache.size(3) == value_cache.size(2),
              "key/value cache dimension mismatch");
  // num_heads and head_size are taken from `key` and handed to the kernel
  // together with the caches. Presumably the kernel indexes the caches with
  // them, so they must agree with key_cache.
  TORCH_CHECK(key.size(1) == key_cache.size(1),
              "key num_heads must match key_cache num_heads");
  TORCH_CHECK(key.size(2) == key_cache.size(3),
              "key head_size must match key_cache head_size");

  int num_tokens = slot_mapping.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  // key_cache layout: [num_blocks, num_heads, block_size, head_size]
  int block_size = key_cache.size(2);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  // A zero-sized grid is an invalid launch configuration; nothing to write.
  if (num_tokens == 0) {
    return;
  }

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
                             CALL_RESHAPE_AND_CACHE_CUDA);
}

912
913
// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
914
// KV_DTYPE is the real data type of kv-cache.
915
916
917
918
919
920
921
922
923
924
// Launches vllm::reshape_and_cache_flash_kernel for one
// (KV_T, CACHE_T, KV_DTYPE) combination.
// The expansion site provides grid, block, stream, key, value, key_cache,
// value_cache, slot_mapping, the block/page/head strides, the key/value
// strides, num_heads, head_size, block_size, k_scale and v_scale.
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE)             \
  vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE>           \
      <<<grid, block, 0, stream>>>(                                       \
          reinterpret_cast<KV_T*>(key.data_ptr()),                        \
          reinterpret_cast<KV_T*>(value.data_ptr()),                      \
          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),               \
          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),             \
          slot_mapping.data_ptr<int64_t>(), block_stride, page_stride,    \
          head_stride, key_stride, value_stride, num_heads, head_size,    \
          block_size, reinterpret_cast<const float*>(k_scale.data_ptr()), \
          reinterpret_cast<const float*>(v_scale.data_ptr()));
926

927
// Scatters key/value into "flash" layout caches
// [num_blocks, block_size, num_heads, head_size].
// Only key_cache's block/page/head strides are passed to the kernel.
// Launches one CUDA block per slot_mapping entry with
// min(num_heads * head_size, 512) threads.
void reshape_and_cache_flash(
    torch::Tensor& key,        // [num_tokens, num_heads, head_size]
    torch::Tensor& value,      // [num_tokens, num_heads, head_size]
    torch::Tensor& key_cache,  // [num_blocks, block_size, num_heads, head_size]
    torch::Tensor&
        value_cache,  // [num_blocks, block_size, num_heads, head_size]
    torch::Tensor& slot_mapping,  // [num_tokens] or [num_actual_tokens]
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale) {
  // NOTE(woosuk): In vLLM V1, key.size(0) can be different from
  // slot_mapping.size(0) because of padding for CUDA graphs.
  // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
  // both include padding.
  // In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0)
  // since key includes padding for CUDA graphs, while slot_mapping does not.
  // In this case, slot_mapping.size(0) represents the actual number of tokens
  // before padding.
  // For compatibility with both cases, we use slot_mapping.size(0) as the
  // number of tokens.
  int num_tokens = slot_mapping.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(1);

  int64_t key_stride = key.stride(0);
  int64_t value_stride = value.stride(0);
  int64_t block_stride = key_cache.stride(0);
  int64_t page_stride = key_cache.stride(1);
  int64_t head_stride = key_cache.stride(2);
  // The kernel receives a single set of cache strides (taken from
  // key_cache), so value_cache must share at least the block stride.
  TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0),
              "key_cache and value_cache must have the same block stride");

  // A zero-sized grid is an invalid launch configuration; nothing to write.
  if (num_tokens == 0) {
    return;
  }

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
                             CALL_RESHAPE_AND_CACHE_FLASH);
}

967
968
// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
// KV_DTYPE is the real data type of kv-cache.
// Launches vllm::read_cache_kernel for one (KV_T, CACHE_T, KV_DTYPE)
// combination. Caches are passed as per-layer address tables
// (key_cache_ptrs_tensor / value_cache_ptrs_tensor). All other names come
// from the expansion site.
#define CALL_READ_CACHE(KV_T, CACHE_T, KV_DTYPE)                      \
  vllm::read_cache_kernel<KV_T, CACHE_T, KV_DTYPE>                    \
      <<<grid, block, 0, stream>>>(                                   \
          static_cast<KV_T*>(keys.data_ptr()),                        \
          static_cast<KV_T*>(values.data_ptr()),                      \
          key_cache_ptrs_tensor.data_ptr<int64_t>(),                  \
          value_cache_ptrs_tensor.data_ptr<int64_t>(),                \
          slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
          num_heads, head_size, block_size, x, num_tokens);

// Launches read_cache_kernel over every layer's key/value cache, using the
// cache slots listed in slot_mapping and the stacked keys/values tensors.
// NOTE(review): the copy direction (cache -> keys/values) is inferred from
// the function and kernel names; the kernel body is outside this chunk.
// Grid: (num_layers, num_tokens); block: min(num_heads * head_size, 512).
void read_cache(
    torch::Tensor& keys,    // [num_layers, seq_len, num_heads, head_size]
    torch::Tensor& values,  // [num_layers, seq_len, num_heads, head_size]
    std::vector<torch::Tensor> const&
        key_caches,  // [num_blocks, num_heads, head_size/x, block_size, x]
    std::vector<torch::Tensor> const&
        value_caches,  // [num_blocks, num_heads, head_size, block_size]
    torch::Tensor& slot_mapping,  // [num_tokens]
    const std::string& kv_cache_dtype) {
  TORCH_CHECK(key_caches.size() == value_caches.size(),
              "key_caches and value_caches must have the same length");
  const int num_layers = static_cast<int>(key_caches.size());
  if (num_layers == 0) {
    return;
  }

  torch::Device cache_device = key_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda());

  // Per-layer base addresses of the caches, shipped to the GPU below.
  // std::vector replaces variable-length arrays. VLAs are a non-standard
  // C++ extension and would place a num_layers-sized buffer on the host
  // stack.
  std::vector<int64_t> key_cache_ptrs(num_layers);
  std::vector<int64_t> value_cache_ptrs(num_layers);
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    key_cache_ptrs[layer_idx] =
        reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
    value_cache_ptrs[layer_idx] =
        reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  }

  int num_tokens = keys.size(1);
  auto kv_dtype = keys.dtype();
  torch::Tensor key_cache = key_caches[0];
  torch::Tensor value_cache = value_caches[0];

  int key_stride = keys.stride(1);
  int value_stride = values.stride(1);

  int num_heads = value_cache.size(1);
  int head_size = value_cache.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  // A zero-sized grid is an invalid launch configuration; nothing to read.
  if (num_tokens == 0) {
    return;
  }

  // Move the data structures to the GPU.
  // NOTE: This synchronizes the CPU and GPU.
  torch::Tensor key_cache_ptrs_tensor =
      torch::from_blob(key_cache_ptrs.data(), {num_layers}, torch::kInt64)
          .to(cache_device);
  torch::Tensor value_cache_ptrs_tensor =
      torch::from_blob(value_cache_ptrs.data(), {num_layers}, torch::kInt64)
          .to(cache_device);

  dim3 grid(num_layers, num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(slot_mapping));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_BY_KV_CACHE_DTYPE(kv_dtype, kv_cache_dtype, CALL_READ_CACHE);
}

// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
// Launches vllm::write_cache_multi_layers_kernel for one
// (KV_T, CACHE_T, KV_DTYPE) combination. keys/values are cast to KV_T*;
// the caches are reached through per-layer address tables. All other names
// come from the expansion site.
#define CALL_WRITE_CACHE_MULTI_LAYERS(KV_T, CACHE_T, KV_DTYPE)        \
  vllm::write_cache_multi_layers_kernel<KV_T, CACHE_T, KV_DTYPE>      \
      <<<grid, block, 0, stream>>>(                                   \
          static_cast<KV_T*>(keys.data_ptr()),                        \
          static_cast<KV_T*>(values.data_ptr()),                      \
          key_cache_ptrs_tensor.data_ptr<int64_t>(),                  \
          value_cache_ptrs_tensor.data_ptr<int64_t>(),                \
          slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
          num_heads, head_size, block_size, x, num_tokens);

// Launches write_cache_multi_layers_kernel over every layer's key/value
// cache, using the cache slots listed in slot_mapping and the stacked
// keys/values tensors.
// NOTE(review): the copy direction (keys/values -> cache) is inferred from
// the function name; the kernel body is outside this chunk.
// Grid: (num_layers, num_tokens); block: min(num_heads * head_size, 512).
void write_cache_multi_layers(
    torch::Tensor& keys,    // [num_layers, seq_len, num_heads, head_size]
    torch::Tensor& values,  // [num_layers, seq_len, num_heads, head_size]
    std::vector<torch::Tensor> const&
        key_caches,  // [num_blocks, num_heads, head_size/x, block_size, x]
    std::vector<torch::Tensor> const&
        value_caches,  // [num_blocks, num_heads, head_size, block_size]
    torch::Tensor& slot_mapping,  // [num_tokens]
    const std::string& kv_cache_dtype) {
  TORCH_CHECK(key_caches.size() == value_caches.size(),
              "key_caches and value_caches must have the same length");
  const int num_layers = static_cast<int>(key_caches.size());
  if (num_layers == 0) {
    return;
  }

  torch::Device cache_device = key_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda());

  // Per-layer base addresses of the caches, shipped to the GPU below.
  // std::vector replaces variable-length arrays. VLAs are a non-standard
  // C++ extension and would place a num_layers-sized buffer on the host
  // stack.
  std::vector<int64_t> key_cache_ptrs(num_layers);
  std::vector<int64_t> value_cache_ptrs(num_layers);
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    key_cache_ptrs[layer_idx] =
        reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
    value_cache_ptrs[layer_idx] =
        reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  }

  auto kv_dtype = keys.dtype();
  int num_tokens = keys.size(1);
  torch::Tensor key_cache = key_caches[0];
  torch::Tensor value_cache = value_caches[0];

  int key_stride = keys.stride(1);
  int value_stride = values.stride(1);

  int num_heads = value_cache.size(1);
  int head_size = value_cache.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  // A zero-sized grid is an invalid launch configuration; nothing to write.
  if (num_tokens == 0) {
    return;
  }

  // Move the data structures to the GPU.
  // NOTE: This synchronizes the CPU and GPU.
  torch::Tensor key_cache_ptrs_tensor =
      torch::from_blob(key_cache_ptrs.data(), {num_layers}, torch::kInt64)
          .to(cache_device);
  torch::Tensor value_cache_ptrs_tensor =
      torch::from_blob(value_cache_ptrs.data(), {num_layers}, torch::kInt64)
          .to(cache_device);

  dim3 grid(num_layers, num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(slot_mapping));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_BY_KV_CACHE_DTYPE(kv_dtype, kv_cache_dtype,
                             CALL_WRITE_CACHE_MULTI_LAYERS);
}

1111
1112
1113
1114
1115
1116
1117
1118
// Launches vllm::concat_and_cache_mla_kernel for one
// (KV_T, CACHE_T, KV_DTYPE) combination.
// The expansion site provides grid, block, stream, kv_c, k_pe, kv_cache,
// slot_mapping, the stride/size locals and scale.
#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE)              \
  vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE>            \
      <<<grid, block, 0, stream>>>(                                     \
          reinterpret_cast<KV_T*>(kv_c.data_ptr()),                     \
          reinterpret_cast<KV_T*>(k_pe.data_ptr()),                     \
          reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()),              \
          slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride, \
          kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size,   \
          reinterpret_cast<const float*>(scale.data_ptr()));

1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
// Launches vllm::concat_and_cache_ds_mla_kernel. The expansion site
// supplies grid, block, stream, kv_c, k_pe, kv_cache, slot_mapping, scale
// and the stride/size locals.
#define CALL_CONCAT_AND_CACHE_DS_MLA(KV_T, CACHE_T, KV_DTYPE)           \
  vllm::concat_and_cache_ds_mla_kernel<KV_T, CACHE_T, KV_DTYPE>         \
      <<<grid, block, 0, stream>>>(                                     \
          static_cast<KV_T*>(kv_c.data_ptr()),                          \
          static_cast<KV_T*>(k_pe.data_ptr()),                          \
          static_cast<CACHE_T*>(kv_cache.data_ptr()),                   \
          slot_mapping.data_ptr<int64_t>(),                             \
          block_stride, entry_stride, kv_c_stride, k_pe_stride,         \
          kv_lora_rank, pe_dim, block_size,                             \
          static_cast<const float*>(scale.data_ptr()));

1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
// Writes each token's latent vector kv_c and rotary part k_pe into the MLA
// kv_cache entry selected by slot_mapping.
//
// With kv_cache_dtype == "fp8_ds_mla":
//   * each entry is a 656-byte record;
//   * the DS-MLA kernel is launched with 576 threads per token.
// Otherwise:
//   * each entry holds kv_lora_rank + pe_dim cache elements;
//   * the plain MLA kernel is launched with min(kv_lora_rank, 512) threads.
void concat_and_cache_mla(
    torch::Tensor& kv_c,          // [num_tokens, kv_lora_rank]
    torch::Tensor& k_pe,          // [num_tokens, pe_dim]
    torch::Tensor& kv_cache,      // [num_blocks, block_size, (kv_lora_rank +
                                  // pe_dim)]
    torch::Tensor& slot_mapping,  // [num_tokens] or [num_actual_tokens]
    const std::string& kv_cache_dtype, torch::Tensor& scale) {
  // NOTE(woosuk): In vLLM V1, key.size(0) can be different from
  // slot_mapping.size(0) because of padding for CUDA graphs.
  // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
  // both include padding.
  // In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0)
  // since key includes padding for CUDA graphs, while slot_mapping does not.
  // In this case, slot_mapping.size(0) represents the actual number of tokens
  // before padding.
  // For compatibility with both cases, we use slot_mapping.size(0) as the
  // number of tokens.
  int num_tokens = slot_mapping.size(0);
  int kv_lora_rank = kv_c.size(1);
  int pe_dim = k_pe.size(1);
  int block_size = kv_cache.size(1);

  const bool is_ds_mla = (kv_cache_dtype == "fp8_ds_mla");
  if (is_ds_mla) {
    TORCH_CHECK(kv_lora_rank == 512, "kv_lora_rank must be 512 for fp8_ds_mla");
    TORCH_CHECK(pe_dim == 64, "pe_dim must be 64 for fp8_ds_mla");
    TORCH_CHECK(kv_cache.size(2) == 656 / kv_cache.itemsize(),
                "kv_cache.size(2) must be 656 bytes for fp8_ds_mla");
    TORCH_CHECK(kv_c.itemsize() == 2,
                "kv_c.itemsize() must be 2 for fp8_ds_mla");
    TORCH_CHECK(k_pe.itemsize() == 2,
                "k_pe.itemsize() must be 2 for fp8_ds_mla");
  } else {
    TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim,
                "kv_cache.size(2) must equal kv_lora_rank + pe_dim");
  }

  int kv_c_stride = kv_c.stride(0);
  int k_pe_stride = k_pe.stride(0);
  int block_stride = kv_cache.stride(0);
  int entry_stride = kv_cache.stride(1);

  // A zero-sized grid is an invalid launch configuration; nothing to write.
  if (num_tokens == 0) {
    return;
  }

  const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  dim3 grid(num_tokens);
  if (is_ds_mla) {
    // For the NoPE part, each tile of 128 elements is handled by 4 warps
    // (128 threads). There are 4 total tiles, so 16 warps (512 threads).
    // The first thread of the first warp in each tile writes the scale
    // value for the tile. The RoPE part (last 64 elements) is handled
    // by another 2 warps (64 threads).
    // So in total, we use 18 warps (576 threads) per block.
    dim3 block(576);
    DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
                               CALL_CONCAT_AND_CACHE_DS_MLA);
  } else {
    dim3 block(std::min(kv_lora_rank, 512));
    DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
                               CALL_CONCAT_AND_CACHE_MLA);
  }
}

Woosuk Kwon's avatar
Woosuk Kwon committed
1195
namespace vllm {

// Element-wise conversion of src_cache into dst_cache through
// fp8::scaled_convert with a single scalar scale.
// CUDA block blockIdx.x handles the block_stride consecutive elements
// starting at blockIdx.x * block_stride. Its threads step through that span
// in increments of blockDim.x.
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
                                   Tout* __restrict__ dst_cache,
                                   const float scale,
                                   const int64_t block_stride) {
  const int64_t block_idx = blockIdx.x;
  const int64_t block_base = block_idx * block_stride;
  // 64-bit loop index: an int index would overflow when block_stride
  // exceeds INT_MAX.
  for (int64_t i = threadIdx.x; i < block_stride; i += blockDim.x) {
    const int64_t idx = block_base + i;
    dst_cache[idx] =
        fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], scale);
  }
}

}  // namespace vllm
1211

1212
1213
1214
// Launches vllm::convert_fp8_kernel converting Tin -> Tout. The expansion
// site provides grid, block, stream, src_cache, dst_cache, scale and
// block_stride.
#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE)                                \
  vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
      reinterpret_cast<Tin*>(src_cache.data_ptr()),                          \
      reinterpret_cast<Tout*>(dst_cache.data_ptr()), scale, block_stride);
1216

1217
// Only for testing.
1218
// Element-wise conversion between a uint8_t-stored fp8 cache and a
// float / half / bfloat16 cache, applying `scale` via fp8::scaled_convert.
//
// Direction is chosen from the dtypes:
//   * if src_cache is Float/Half/BFloat16, it is converted into uint8_t;
//   * otherwise dst_cache's dtype selects a uint8_t -> dst conversion.
// Any other dtype combination, or an unknown kv_cache_dtype, raises.
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                 const double scale, const std::string& kv_cache_dtype) {
  torch::Device src_device = src_cache.device();
  torch::Device dst_device = dst_cache.device();
  TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU");
  TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU");
  TORCH_CHECK(src_device.index() == dst_device.index(),
              "src and dst must be on the same GPU");
  at::cuda::OptionalCUDAGuard device_guard(src_device);

  int64_t num_blocks = src_cache.size(0);
  int64_t block_stride = src_cache.stride(0);

  dim3 grid(num_blocks);
  dim3 block(std::min(block_stride, int64_t(512)));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (kv_cache_dtype == "auto") {
    if (src_cache.dtype() == at::ScalarType::Float) {
      CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto);
    } else if (src_cache.dtype() == at::ScalarType::Half) {
      CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);
    } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
      CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);
    } else if (dst_cache.dtype() == at::ScalarType::Float) {
      CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
    } else if (dst_cache.dtype() == at::ScalarType::Half) {
      CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
    } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
      CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
    } else {
      // Previously fell through silently without launching anything.
      TORCH_CHECK(false, "Unsupported dtype combination for convert_fp8: src=",
                  src_cache.scalar_type(), ", dst=", dst_cache.scalar_type());
    }
  } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
    if (src_cache.dtype() == at::ScalarType::Float) {
      CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3);
    } else if (src_cache.dtype() == at::ScalarType::Half) {
      CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
    } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
      CALL_CONVERT_FP8(uint8_t, __nv_bfloat16,
                       vllm::Fp8KVCacheDataType::kFp8E4M3);
    } else if (dst_cache.dtype() == at::ScalarType::Float) {
      CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
    } else if (dst_cache.dtype() == at::ScalarType::Half) {
      CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
    } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
      CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
                       vllm::Fp8KVCacheDataType::kFp8E4M3);
    } else {
      TORCH_CHECK(false, "Unsupported dtype combination for convert_fp8: src=",
                  src_cache.scalar_type(), ", dst=", dst_cache.scalar_type());
    }
  } else {
    TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
  }
}
1269
1270
1271
1272

namespace vllm {

// Gathers paged KV-cache entries into a dense, token-major buffer, converting
// each element on the way.
//
// Launch layout: grid = (batch, num_splits), 1D thread block.
//   - blockIdx.x selects the sequence; its tokens are written to rows
//     [cu_seq_lens[bid], cu_seq_lens[bid + 1]) of dst.
//   - blockIdx.y selects one of gridDim.y contiguous chunks of that sequence's
//     cache blocks; CUDA blocks whose chunk is empty return immediately.
//   - threadIdx.x strides over the entry_size elements of one token, so all
//     threads of a CUDA block cooperate on the same token at a time.
//
// Template parameters:
//   scalar_t - element type written to dst.
//   cache_t  - element type as stored in src_cache.
//   kv_dt    - kAuto: plain static_cast; otherwise fp8::scaled_convert using
//              the scalar *scale.
//
// If seq_starts is given, the block-table lookup for sequence bid is shifted
// by seq_starts[bid] / block_size whole blocks (integer division, so any
// remainder is dropped -- presumably callers pass block-aligned starts).
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void gather_and_maybe_dequant_cache(
    const cache_t* __restrict__ src_cache,    // [NUM_BLOCKS, BLOCK_SIZE,
                                              // ENTRIES...]
    scalar_t* __restrict__ dst,               // [TOT_TOKENS, ENTRIES...]
    const int32_t* __restrict__ block_table,  // [BATCH, BLOCK_INDICES]
    const int32_t* __restrict__ cu_seq_lens,  // [BATCH+1]
    const int32_t block_size, const int32_t entry_size,
    const int64_t block_table_stride, const int64_t cache_block_stride,
    const int64_t cache_entry_stride, const int64_t dst_entry_stride,
    const float* __restrict__ scale,
    const int32_t* __restrict__ seq_starts) {  // Optional: starting offsets per
                                               // batch
  const int64_t bid = blockIdx.x;  // Batch ID
  const int32_t num_splits = gridDim.y;
  const int32_t split = blockIdx.y;
  const int32_t seq_start = cu_seq_lens[bid];
  const int32_t seq_end = cu_seq_lens[bid + 1];
  const int32_t seq_len = seq_end - seq_start;
  const int32_t tot_blocks = cuda_utils::ceil_div(seq_len, block_size);
  const int32_t split_blocks = cuda_utils::ceil_div(tot_blocks, num_splits);

  // Cache-block range [split_start, split_end) handled by this CUDA block.
  const int32_t split_start = split * split_blocks;
  const int32_t split_end = min((split + 1) * split_blocks, tot_blocks);

  const bool is_active_split = (split_start < tot_blocks);
  const bool is_last_split = (split_end == tot_blocks);

  if (!is_active_split) return;

  int32_t full_blocks_end = split_end;
  int32_t partial_block_size = 0;

  // Row of block_table for this sequence. Kept in 64 bits: the product
  // bid * block_table_stride is int64 and previously wrapped silently when it
  // was narrowed into an int32_t for large batches / wide block tables.
  const int64_t batch_offset = bid * block_table_stride;
  // If seq_starts is provided, skip (seq_starts[bid] / block_size) blocks.
  int32_t offset = 0;
  if (seq_starts != nullptr) {
    offset = seq_starts[bid] / block_size;
  }
  const int32_t* batch_block_table = block_table + batch_offset + offset;

  // Adjust dst pointer based on the cumulative sequence lengths.
  dst += seq_start * dst_entry_stride;

  // Only the last split can end in a partially filled cache block.
  if (is_last_split) {
    partial_block_size = seq_len % block_size;
    if (partial_block_size) full_blocks_end -= 1;
  }

  // Copies (and converts) one token's entry_size elements; threads of the
  // CUDA block stride across the entry.
  auto copy_entry = [&](const cache_t* __restrict__ _src,
                        scalar_t* __restrict__ _dst) {
    for (int i = threadIdx.x; i < entry_size; i += blockDim.x) {
      if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
        _dst[i] = static_cast<scalar_t>(_src[i]);
      } else {
        _dst[i] =
            fp8::scaled_convert<scalar_t, cache_t, kv_dt>(_src[i], *scale);
      }
    }
  };

  // Full cache blocks: every one of the block_size slots holds a token.
  for (int pid = split_start; pid < full_blocks_end; ++pid) {
    auto block_id = batch_block_table[pid];
    auto block_start_ptr = src_cache + block_id * cache_block_stride;
    // Widen before multiplying so the token index is formed in 64 bits.
    auto block_dst_ptr =
        dst + static_cast<int64_t>(pid) * block_size * dst_entry_stride;
    for (int eid = 0; eid < block_size; ++eid) {
      copy_entry(block_start_ptr + eid * cache_entry_stride,
                 block_dst_ptr + eid * dst_entry_stride);
    }
  }

  // Tail: only the first partial_block_size slots of the final cache block.
  if (partial_block_size) {
    auto block_id = batch_block_table[full_blocks_end];
    auto block_start_ptr = src_cache + block_id * cache_block_stride;
    auto block_dst_ptr = dst + static_cast<int64_t>(full_blocks_end) *
                                   block_size * dst_entry_stride;
    for (int eid = 0; eid < partial_block_size; ++eid) {
      copy_entry(block_start_ptr + eid * cache_entry_stride,
                 block_dst_ptr + eid * dst_entry_stride);
    }
  }
}

}  // namespace vllm

// Macro to dispatch the kernel based on the data type.
// SCALAR_T is the data type of the destination tensor.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
// The expansion site must have grid, block, stream, src_cache, dst,
// block_table, cu_seq_lens, block_size, entry_size, the four *_stride locals,
// scale and seq_starts_ptr in scope.
#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE)                      \
  vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE>         \
      <<<grid, block, 0, stream>>>(                                         \
          reinterpret_cast<CACHE_T*>(src_cache.data_ptr()),                 \
          reinterpret_cast<SCALAR_T*>(dst.data_ptr()),                      \
          block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
          block_size, entry_size, block_table_stride, cache_block_stride,   \
          cache_entry_stride, dst_entry_stride,                             \
          reinterpret_cast<const float*>(scale.data_ptr()), seq_starts_ptr);

// Gather sequences from the cache into the destination tensor.
//  - cu_seq_lens contains the cumulative sequence lengths for each batch; it
//    must hold at least batch_size + 1 entries (the kernel reads
//    cu_seq_lens[bid + 1] for every sequence).
//  - block_table contains the cache block indices for each sequence
//  - Optionally, seq_starts (if provided) offsets the starting block index by
//  (seq_starts[bid] / page_size)
//  - scale is passed to the kernel as a float*; the kernel only dereferences
//    it on non-kAuto paths (presumably any fp8 kv_cache_dtype).
//  - A batch_size of 0 is a no-op (no kernel is launched).
void gather_and_maybe_dequant_cache(
    torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
    torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
    torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
    torch::Tensor const& cu_seq_lens,  // [BATCH+1]
    int64_t batch_size, const std::string& kv_cache_dtype,
    torch::Tensor const& scale,
    std::optional<torch::Tensor> seq_starts = std::nullopt) {
  at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  int32_t block_size = src_cache.size(1);
  int32_t entry_size = src_cache.flatten(2, -1).size(2);

  TORCH_CHECK(block_table.dtype() == torch::kInt32,
              "block_table must be int32");
  TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32,
              "cu_seq_lens must be int32");
  if (seq_starts.has_value()) {
    TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
                "seq_starts must be int32");
  }

  TORCH_CHECK(src_cache.device() == dst.device(),
              "src_cache and dst must be on the same device");
  TORCH_CHECK(src_cache.device() == block_table.device(),
              "src_cache and block_table must be on the same device");
  TORCH_CHECK(src_cache.device() == cu_seq_lens.device(),
              "src_cache and cu_seq_lens must be on the same device");
  if (seq_starts.has_value()) {
    TORCH_CHECK(src_cache.device() == seq_starts.value().device(),
                "src_cache and seq_starts must be on the same device");
  }

  TORCH_CHECK(batch_size >= 0, "batch_size must be non-negative, got ",
              batch_size);
  // A zero-sized grid is an invalid launch configuration whose error would be
  // silently left pending; with no sequences there is nothing to gather.
  if (batch_size == 0) return;
  // The kernel reads cu_seq_lens[bid + 1] unconditionally for each bid.
  TORCH_CHECK(cu_seq_lens.numel() >= batch_size + 1,
              "cu_seq_lens must hold at least batch_size + 1 entries");

  int64_t block_table_stride = block_table.stride(0);
  int64_t cache_block_stride = src_cache.stride(0);
  int64_t cache_entry_stride = src_cache.stride(1);
  int64_t dst_entry_stride = dst.stride(0);

  // Decide on the number of splits based on the batch size.
  int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
  dim3 grid(batch_size, num_splits);
  dim3 block(1024);

  const int32_t* seq_starts_ptr =
      seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;

  DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE);
}

namespace vllm {
// Copies paged KV-cache entries into a dense, token-major buffer with no
// element conversion (dst[i] = src[i]).
//
// Note(hc): The cp_gather_cache allows seq_starts to no longer be divisible by
// block_size.
//
// Launch layout: grid = (batch, num_splits), 1D thread block.
//   - blockIdx.x selects the sequence; its tokens are written to rows
//     [cu_seq_lens[bid], cu_seq_lens[bid + 1]) of dst.
//   - blockIdx.y selects one of gridDim.y contiguous ranges of *tokens* of that
//     sequence; CUDA blocks whose range is empty return immediately.
//   - threadIdx.x strides over the entry_size elements of one token.
//
// Token pid of sequence bid is read from logical slot (seq_starts[bid] + pid)
// of block_table row bid (seq_starts[bid] taken as 0 when seq_starts is null).
template <typename scalar_t>
__global__ void cp_gather_cache(
    const scalar_t* __restrict__ src_cache,   // [NUM_BLOCKS, BLOCK_SIZE,
                                              // ENTRY_SIZE]
    scalar_t* __restrict__ dst,               // [TOT_TOKENS, ENTRY_SIZE]
    const int32_t* __restrict__ block_table,  // [BATCH, BLOCK_INDICES]
    const int32_t* __restrict__ cu_seq_lens,  // [BATCH+1]
    const int32_t block_size, const int32_t entry_size,
    const int64_t block_table_stride, const int64_t cache_block_stride,
    const int64_t cache_entry_stride, const int64_t dst_entry_stride,
    const int32_t* __restrict__ seq_starts  // Optional: starting offsets per
                                            // batch
) {
  const int64_t bid = blockIdx.x;  // Batch ID
  const int32_t num_splits = gridDim.y;
  const int32_t split = blockIdx.y;
  const int32_t seq_start = cu_seq_lens[bid];
  const int32_t seq_end = cu_seq_lens[bid + 1];
  const int32_t seq_len = seq_end - seq_start;
  const int32_t tot_slots = seq_len;
  const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits);

  // Token range [split_start, split_end) handled by this CUDA block.
  const int32_t split_start = split * split_slots;
  const int32_t split_end = min((split + 1) * split_slots, tot_slots);

  const bool is_active_split = (split_start < tot_slots);

  if (!is_active_split) return;

  // Row of block_table for this sequence. Kept in 64 bits: the product
  // bid * block_table_stride is int64 and previously wrapped silently when it
  // was narrowed into an int32_t for large batches / wide block tables.
  const int64_t batch_offset = bid * block_table_stride;
  // Logical slot of the first token this CUDA block copies, split into a
  // block-table index (offset_div) and a slot within that cache block.
  int32_t offset = split_start;
  if (seq_starts != nullptr) {
    offset += seq_starts[bid];
  }
  int32_t offset_div = offset / block_size;
  offset = offset % block_size;
  const int32_t* batch_block_table = block_table + batch_offset;

  // Adjust dst pointer based on the cumulative sequence lengths.
  dst += seq_start * dst_entry_stride;

  // Copies one token's entry_size elements; threads stride across the entry.
  auto copy_entry = [&](const scalar_t* __restrict__ _src,
                        scalar_t* __restrict__ _dst) {
    for (int i = threadIdx.x; i < entry_size; i += blockDim.x)
      _dst[i] = _src[i];
  };

  for (int pid = split_start; pid < split_end; ++pid) {
    auto block_id = batch_block_table[offset_div];
    auto block_start_ptr = src_cache + block_id * cache_block_stride;
    auto block_dst_ptr = dst + pid * dst_entry_stride;
    copy_entry(block_start_ptr + offset * cache_entry_stride, block_dst_ptr);
    offset += 1;
    // bump to next block
    if (offset == block_size) {
      offset_div += 1;
      offset = 0;
    }
  }
}
}  // namespace vllm

// Macro to dispatch the kernel based on the data type.
// CPY_DTYPE is an unsigned integer type with the same width as the cache
// element; the kernel only moves bits, so no value conversion happens.
// The expansion site must have grid, block, stream, src_cache, dst,
// block_table, cu_seq_lens, block_size, entry_size, the four *_stride locals
// and seq_starts_ptr in scope.
#define CALL_CP_GATHER_CACHE(CPY_DTYPE)                                 \
  vllm::cp_gather_cache<CPY_DTYPE><<<grid, block, 0, stream>>>(         \
      reinterpret_cast<CPY_DTYPE*>(src_cache.data_ptr()),               \
      reinterpret_cast<CPY_DTYPE*>(dst.data_ptr()),                     \
      block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
      block_size, entry_size, block_table_stride, cache_block_stride,   \
      cache_entry_stride, dst_entry_stride, seq_starts_ptr);

// Gather sequences from the cache into the destination tensor.
//  - cu_seq_lens contains the cumulative sequence lengths for each batch; it
//    must hold at least batch_size + 1 entries (the kernel reads
//    cu_seq_lens[bid + 1] for every sequence).
//  - block_table contains the cache block indices for each sequence
//  - Optionally, seq_starts (if provided) offsets the starting slot index by
//  seq_starts[bid]
//  - src_cache and dst must share a dtype; elements are copied bit-for-bit
//    through an unsigned integer type of the same width (8/16/32/64 bits).
//  - A batch_size of 0 is a no-op (no kernel is launched).
void cp_gather_cache(
    torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
    torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
    torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
    torch::Tensor const& cu_seq_lens,  // [BATCH+1]
    int64_t batch_size,
    std::optional<torch::Tensor> seq_starts = std::nullopt) {
  at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  int32_t block_size = src_cache.size(1);
  int32_t entry_size = src_cache.flatten(2, -1).size(2);

  TORCH_CHECK(block_table.dtype() == torch::kInt32,
              "block_table must be int32");
  TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32,
              "cu_seq_lens must be int32");
  if (seq_starts.has_value()) {
    TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
                "seq_starts must be int32");
  }

  TORCH_CHECK(src_cache.device() == dst.device(),
              "src_cache and dst must be on the same device");
  TORCH_CHECK(src_cache.device() == block_table.device(),
              "src_cache and block_table must be on the same device");
  TORCH_CHECK(src_cache.device() == cu_seq_lens.device(),
              "src_cache and cu_seq_lens must be on the same device");
  if (seq_starts.has_value()) {
    TORCH_CHECK(src_cache.device() == seq_starts.value().device(),
                "src_cache and seq_starts must be on the same device");
  }

  TORCH_CHECK(src_cache.dtype() == dst.dtype(),
              "src_cache and dst must have the same dtype");

  TORCH_CHECK(batch_size >= 0, "batch_size must be non-negative, got ",
              batch_size);
  // A zero-sized grid is an invalid launch configuration whose error would be
  // silently left pending; with no sequences there is nothing to copy.
  if (batch_size == 0) return;
  // The kernel reads cu_seq_lens[bid + 1] unconditionally for each bid.
  TORCH_CHECK(cu_seq_lens.numel() >= batch_size + 1,
              "cu_seq_lens must hold at least batch_size + 1 entries");

  int64_t block_table_stride = block_table.stride(0);
  int64_t cache_block_stride = src_cache.stride(0);
  int64_t cache_entry_stride = src_cache.stride(1);
  int64_t dst_entry_stride = dst.stride(0);

  // Decide on the number of splits based on the batch size.
  int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
  dim3 grid(batch_size, num_splits);
  dim3 block(1024);

  const int dtype_bits = src_cache.element_size() * 8;
  const int32_t* seq_starts_ptr =
      seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;

  // The copy is a pure bit move, so any dtype of a supported width works.
  if (dtype_bits == 64) {
    CALL_CP_GATHER_CACHE(uint64_t);
  } else if (dtype_bits == 32) {
    CALL_CP_GATHER_CACHE(uint32_t);
  } else if (dtype_bits == 16) {
    CALL_CP_GATHER_CACHE(uint16_t);
  } else if (dtype_bits == 8) {
    CALL_CP_GATHER_CACHE(uint8_t);
  } else {
    TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
  }
}

// Macro to dispatch the kernel based on the data type.
// KV_T is the element type of k, CACHE_T the stored element type of kv_cache,
// and KV_DTYPE the Fp8KVCacheDataType selected by DISPATCH_BY_KV_CACHE_DTYPE.
#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE)         \
  vllm::indexer_k_quant_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE>       \
      <<<grid, block, 0, stream>>>(                                     \
          reinterpret_cast<KV_T*>(k.data_ptr()),                        \
          reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()),              \
          slot_mapping.data_ptr<int64_t>(), head_dim, quant_block_size, \
          cache_block_size, cache_stride, use_ue8m0);

// Launches indexer_k_quant_and_cache_kernel (defined elsewhere) to write k into
// kv_cache, dispatched with kv-cache dtype "fp8_e4m3".
//  - slot_mapping is read as int64; presumably one destination slot per token
//    (TODO confirm against the kernel).
//  - quant_block_size must be positive and divide head_dim.
//  - scale_fmt == "ue8m0" sets the kernel's use_ue8m0 flag; any other value
//    clears it.
//  - Empty input (no tokens or head_dim == 0) is a no-op.
// Launch layout: grid = (num_tokens, ceil(head_dim / (quant_block_size * 4))),
// block = (32, 4).
void indexer_k_quant_and_cache(
    torch::Tensor& k,             // [num_tokens, head_dim]
    torch::Tensor& kv_cache,      // [num_blocks, block_size, cache_stride]
    torch::Tensor& slot_mapping,  // [num_tokens]
    int64_t quant_block_size,     // quantization block size
    const std::string& scale_fmt) {
  int num_tokens = k.size(0);
  int head_dim = k.size(1);
  int cache_block_size = kv_cache.size(1);
  int cache_stride = kv_cache.size(2);
  bool use_ue8m0 = scale_fmt == "ue8m0";

  TORCH_CHECK(k.device() == kv_cache.device(),
              "k and kv_cache must be on the same device");
  TORCH_CHECK(k.device() == slot_mapping.device(),
              "k and slot_mapping must be on the same device");
  // Checked before the modulo below, which would otherwise divide by zero.
  TORCH_CHECK(quant_block_size > 0, "quant_block_size must be positive, got ",
              quant_block_size);
  TORCH_CHECK(head_dim % quant_block_size == 0,
              "head_dim must be divisible by quant_block_size");

  // Either dimension being zero yields a zero-sized grid, which is an invalid
  // launch configuration; there is nothing to write.
  if (num_tokens == 0 || head_dim == 0) return;

  constexpr int vec_size = 4;
  dim3 grid(num_tokens, (head_dim + quant_block_size * vec_size - 1) /
                            (quant_block_size * vec_size));
  dim3 block(32, vec_size);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(k));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "fp8_e4m3",
                             CALL_INDEXER_K_QUANT_AND_CACHE);
}

// Macro to dispatch the kernel based on the data type.
// KV_T is the element type of k and CACHE_T the element type stored in
// kv_cache. The expansion site must have grid, block, stream, k, kv_cache,
// slot_mapping, head_dim, cache_block_size and cache_stride in scope.
#define CALL_INDEXER_K_CACHE(KV_T, CACHE_T)                            \
  vllm::indexer_k_cache_kernel<KV_T, CACHE_T>                          \
      <<<grid, block, 0, stream>>>(                                    \
          reinterpret_cast<KV_T*>(k.data_ptr()),                       \
          reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()),             \
          slot_mapping.data_ptr<int64_t>(), head_dim,                  \
          cache_block_size, cache_stride);

// Launches indexer_k_cache_kernel (defined elsewhere) to write k into kv_cache
// without quantization.
//  - k may be float, double or half (AT_DISPATCH_FLOATING_TYPES_AND_HALF);
//    kv_cache may be float or half; any other kv_cache dtype raises.
//  - slot_mapping is read as int64; presumably one destination slot per token
//    (TODO confirm against the kernel).
//  - scale_fmt is accepted for signature parity with
//    indexer_k_quant_and_cache but is not used.
//  - Empty input (no tokens or head_dim == 0) is a no-op.
// Launch layout: grid = (num_tokens, ceil(head_dim / 4)), block = (32, 4).
void indexer_k_cache(
    torch::Tensor& k,             // [num_tokens, head_dim]
    torch::Tensor& kv_cache,      // [num_blocks, block_size, cache_stride]
    torch::Tensor& slot_mapping,  // [num_tokens]
    const std::string& scale_fmt) {
  static_cast<void>(scale_fmt);  // intentionally unused
  int num_tokens = k.size(0);
  int head_dim = k.size(1);
  int cache_block_size = kv_cache.size(1);
  int cache_stride = kv_cache.size(2);

  TORCH_CHECK(k.device() == kv_cache.device(),
              "k and kv_cache must be on the same device");
  TORCH_CHECK(k.device() == slot_mapping.device(),
              "k and slot_mapping must be on the same device");

  // Either dimension being zero yields a zero-sized grid, which is an invalid
  // launch configuration; there is nothing to write.
  if (num_tokens == 0 || head_dim == 0) return;

  constexpr int vec_size = 4;
  dim3 grid(num_tokens, (head_dim + vec_size - 1) / vec_size);
  dim3 block(32, vec_size);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(k));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      k.scalar_type(), "indexer_k_cache_k", ([&] {
        using k_t = scalar_t;
        if (kv_cache.scalar_type() == at::ScalarType::Float) {
          CALL_INDEXER_K_CACHE(k_t, float);
        } else if (kv_cache.scalar_type() == at::ScalarType::Half) {
          CALL_INDEXER_K_CACHE(k_t, at::Half);
        } else {
          TORCH_CHECK(false, "Unsupported kv_cache dtype: ", kv_cache.dtype());
        }
      }));
}