cache_kernels.cu 16.5 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
3
#include <c10/cuda/CUDAGuard.h>
Woosuk Kwon's avatar
Woosuk Kwon committed
4

5
#include "cuda_compat.h"
6
#include "dispatch_utils.h"
7
8
9
10
11

#ifdef USE_ROCM
#include "quantization/fp8/amd/quant_utils.cuh"
#else
#include "quantization/fp8/nvidia/quant_utils.cuh"
12
#endif
13

Woosuk Kwon's avatar
Woosuk Kwon committed
14
#include <algorithm>
Woosuk Kwon's avatar
Woosuk Kwon committed
15
16
#include <cassert>
#include <map>
17
#include <vector>
Woosuk Kwon's avatar
Woosuk Kwon committed
18

19
20
21
22
23
#ifdef USE_ROCM
  #include <hip/hip_bf16.h>
  typedef __hip_bfloat16 __nv_bfloat16;
#endif

24
// Copies whole cache blocks from `src` to `dst` as directed by `block_mapping`.
//
// A block is one slice of `src` along dim 0; its byte size is taken from
// `src[0]` (NOTE(review): `dst` is assumed to use the same block size; this
// is not checked here). Row i of `block_mapping` is
// (src_block_number, dst_block_number).
//
// Supported directions: GPU->GPU (same device), GPU->CPU and CPU->GPU. Each
// row is enqueued with cudaMemcpyAsync on the current stream of the GPU
// involved; this function does not synchronize that stream.
//
// Throws (TORCH_CHECK) on an unsupported device combination, on a non-CPU
// `block_mapping`, on out-of-range block numbers, or if a memcpy cannot be
// enqueued.
void swap_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const torch::Tensor& block_mapping) {
  torch::Device src_device = src.device();
  torch::Device dst_device = dst.device();
  cudaMemcpyKind memcpy_type;
  if (src_device.is_cuda() && dst_device.is_cuda()) {
    TORCH_CHECK(
      src_device.index() == dst_device.index(),
      "src and dst must be on the same GPU");
    memcpy_type = cudaMemcpyDeviceToDevice;
  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
    memcpy_type = cudaMemcpyDeviceToHost;
  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
    memcpy_type = cudaMemcpyHostToDevice;
  } else {
    TORCH_CHECK(false, "Invalid device combination");
  }

  // NOTE(youkaichao): keep in mind that `block_mapping` should be
  // a cpu tensor, otherwise every `item` call will require a gpu-cpu
  // synchronization.
  TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");

  const int64_t num_blocks = block_mapping.size(0);
  if (num_blocks == 0) {
    // Nothing to copy; this also avoids indexing src[0] on an empty cache.
    return;
  }

  char* src_ptr = static_cast<char*>(src.data_ptr());
  char* dst_ptr = static_cast<char*>(dst.data_ptr());

  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  const int64_t num_src_blocks = src.size(0);
  const int64_t num_dst_blocks = dst.size(0);
  const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // NOTE(woosuk): This can be slow if the number of blocks is large.
  for (int64_t i = 0; i < num_blocks; i++) {
    const int64_t src_block_number = block_mapping[i][0].item<int64_t>();
    const int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
    // Reject bad block numbers up front: a raw memcpy would otherwise read or
    // write outside the cache allocations.
    TORCH_CHECK(0 <= src_block_number && src_block_number < num_src_blocks,
                "swap_blocks: src block number ", src_block_number,
                " out of range [0, ", num_src_blocks, ")");
    TORCH_CHECK(0 <= dst_block_number && dst_block_number < num_dst_blocks,
                "swap_blocks: dst block number ", dst_block_number,
                " out of range [0, ", num_dst_blocks, ")");
    const int64_t src_offset = src_block_number * block_size_in_bytes;
    const int64_t dst_offset = dst_block_number * block_size_in_bytes;
    const cudaError_t err = cudaMemcpyAsync(
      dst_ptr + dst_offset,
      src_ptr + src_offset,
      block_size_in_bytes,
      memcpy_type,
      stream);
    TORCH_CHECK(err == cudaSuccess,
                "swap_blocks: cudaMemcpyAsync failed: ", cudaGetErrorString(err));
  }
}

namespace vllm {

// Copies one cache block, for both the key and the value cache of one layer.
//
// Launch layout: blockIdx.x selects the layer and blockIdx.y selects the
// (src, dst) pair. Each thread handles elements threadIdx.x,
// threadIdx.x + blockDim.x, ... of the block, so any blockDim.x is valid.
//
// key_cache_ptrs / value_cache_ptrs: per-layer base addresses of the caches,
//   stored as int64_t and reinterpreted as scalar_t*.
// block_mapping: flat array of pairs; element 2*p is the source block number
//   and element 2*p+1 the destination block number of pair p.
// numel_per_block: number of scalar_t elements in one cache block.
template<typename scalar_t>
__global__ void copy_blocks_kernel(
  int64_t* key_cache_ptrs,
  int64_t* value_cache_ptrs,
  const int64_t* __restrict__ block_mapping,
  const int numel_per_block) {
  const int layer = blockIdx.x;
  const int pair = blockIdx.y;

  // Element offsets of the source and destination blocks inside each cache.
  const int64_t from = block_mapping[2 * pair] * numel_per_block;
  const int64_t to = block_mapping[2 * pair + 1] * numel_per_block;

  // Key cache first.
  scalar_t* keys = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer]);
  for (int e = threadIdx.x; e < numel_per_block; e += blockDim.x) {
    keys[to + e] = keys[from + e];
  }

  // Then the value cache, with the same offsets.
  scalar_t* values = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer]);
  for (int e = threadIdx.x; e < numel_per_block; e += blockDim.x) {
    values[to + e] = values[from + e];
  }
}

} // namespace vllm

// Copies cache blocks inside every layer's key and value cache on the GPU.
//
// key_caches / value_caches: one tensor per layer. The device and dtype are
//   taken from key_caches[0], and each block is key_caches[0][0]
//   (NOTE(review): the remaining layers are assumed to match; only the layer
//   counts are checked here).
// block_mapping: int64 tensor of shape (num_pairs, 2) on the cache device;
//   row p is (src_block_number, dst_block_number).
//
// A single kernel launch covers all layers and all pairs. The per-layer
// pointer tables are first copied to the GPU.
void copy_blocks(
  std::vector<torch::Tensor>& key_caches,
  std::vector<torch::Tensor>& value_caches,
  const torch::Tensor& block_mapping) {
  TORCH_CHECK(key_caches.size() == value_caches.size(),
              "key_caches and value_caches must have the same number of layers");
  const int num_layers = static_cast<int>(key_caches.size());
  if (num_layers == 0) {
    return;
  }
  torch::Device cache_device = key_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda());
  // The kernel dereferences block_mapping directly, so it must live on the
  // cache GPU: a CPU tensor would hand a host pointer to the kernel.
  TORCH_CHECK(block_mapping.device() == cache_device,
              "block_mapping must be on the same device as the caches");
  // The kernel reads the mapping as flat (src, dst) pairs.
  const torch::Tensor mapping = block_mapping.contiguous();

  // block_mapping is a 2D tensor with shape (num_pairs, 2).
  const int num_pairs = mapping.size(0);
  const int numel_per_block = key_caches[0][0].numel();
  if (num_pairs == 0 || numel_per_block == 0) {
    // Nothing to copy. A zero-sized grid or block is also an invalid launch
    // configuration.
    return;
  }

  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
  // Pairs are mapped to gridDim.y, which has a device-specific upper bound.
  const int max_grid_y = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  TORCH_CHECK(num_pairs <= max_grid_y,
              "copy_blocks: too many block pairs (", num_pairs,
              "); this device supports at most ", max_grid_y, " per call");

  // Create an array of pointers to the key and value caches.
  // std::vector rather than a variable-length array, which is not standard C++.
  std::vector<int64_t> key_cache_ptrs(num_layers);
  std::vector<int64_t> value_cache_ptrs(num_layers);
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
    value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  }

  // Move the data structures to the GPU.
  // NOTE: This synchronizes the CPU and GPU.
  torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
    key_cache_ptrs.data(), {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
    value_cache_ptrs.data(), {num_layers}, torch::kInt64).to(cache_device);

  // Launch the kernel.
  dim3 grid(num_layers, num_pairs);
  dim3 block(std::min(1024, numel_per_block));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
    key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
      vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key_cache_ptrs_tensor.data_ptr<int64_t>(),
        value_cache_ptrs_tensor.data_ptr<int64_t>(),
        mapping.data_ptr<int64_t>(),
        numel_per_block);
    }));
}

namespace vllm {

// Stores one token's key and value vectors into the paged KV cache.
//
// Launch layout: one thread block per token (blockIdx.x is the token index).
// Threads cover the num_heads * head_size elements of the token in steps of
// blockDim.x. A negative slot_mapping entry marks a padding token, which is
// skipped.
//
// Destination layouts, as encoded by the index arithmetic below:
//   key_cache:   [num_blocks, num_heads, head_size/x, block_size, x]
//   value_cache: [num_blocks, num_heads, head_size, block_size]
// For kv_dt == kAuto the elements are stored unchanged. Otherwise each one
// goes through fp8::scaled_convert<cache_t, scalar_t, kv_dt>(.., kv_scale).
template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_kernel(
  const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ value,         // [num_tokens, num_heads, head_size]
  cache_t* __restrict__ key_cache,            // [num_blocks, num_heads, head_size/x, block_size, x]
  cache_t* __restrict__ value_cache,          // [num_blocks, num_heads, head_size, block_size]
  const int64_t* __restrict__ slot_mapping,   // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x,
  const float kv_scale) {
  const int64_t token = blockIdx.x;
  const int64_t slot = slot_mapping[token];
  if (slot < 0) {
    // Padding token that should be ignored.
    return;
  }

  const int64_t cache_block = slot / block_size;
  const int64_t pos_in_block = slot % block_size;
  const int64_t key_src_base = token * key_stride;
  const int64_t value_src_base = token * value_stride;

  const int elems_per_token = num_heads * head_size;
  for (int elem = threadIdx.x; elem < elems_per_token; elem += blockDim.x) {
    const int head = elem / head_size;
    const int dim = elem % head_size;
    const int chunks_per_head = head_size / x;
    const int chunk = dim / x;
    const int within_chunk = dim % x;

    // key_cache[cache_block][head][chunk][pos_in_block][within_chunk]
    const int64_t key_dst = cache_block * num_heads * chunks_per_head * block_size * x
                            + head * chunks_per_head * block_size * x
                            + chunk * block_size * x
                            + pos_in_block * x
                            + within_chunk;
    // value_cache[cache_block][head][dim][pos_in_block]
    const int64_t value_dst = cache_block * num_heads * head_size * block_size
                              + head * head_size * block_size
                              + dim * block_size
                              + pos_in_block;

    const scalar_t k = key[key_src_base + elem];
    const scalar_t v = value[value_src_base + elem];
    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
      key_cache[key_dst] = k;
      value_cache[value_dst] = v;
    } else {
      key_cache[key_dst] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(k, kv_scale);
      value_cache[value_dst] = fp8::scaled_convert<cache_t, scalar_t, kv_dt>(v, kv_scale);
    }
  }
}

// Stores one token's key and value vectors into caches whose per-block layout
// is [block_size, num_heads, head_size], with consecutive cache blocks
// block_stride elements apart. k_cache and v_cache share that layout.
// The launch layout and the padding rule match reshape_and_cache_kernel.
template<typename scalar_t>
__global__ void reshape_and_cache_flash_kernel(
  const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ value,         // [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ k_cache,             // [num_blocks, block_size, num_heads, head_size]
  scalar_t* __restrict__ v_cache,             // [num_blocks, block_size, num_heads, head_size]
  const int64_t* __restrict__ slot_mapping,   // [num_tokens]
  const int block_stride,
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size) {
  const int64_t token = blockIdx.x;
  const int64_t slot = slot_mapping[token];
  // NOTE: slot can be -1 if the token is padded
  if (slot < 0) {
    return;
  }

  const int64_t cache_block = slot / block_size;
  const int64_t pos_in_block = slot % block_size;
  // Start of this token's [num_heads, head_size] row inside the cache. Within
  // the row, (head, dim) lands at head * head_size + dim, which equals the
  // flat element index, so the row is copied element-for-element.
  const int64_t row_base = cache_block * block_stride
                           + pos_in_block * num_heads * head_size;

  const int elems_per_token = num_heads * head_size;
  for (int elem = threadIdx.x; elem < elems_per_token; elem += blockDim.x) {
    const int64_t dst = row_base + elem;
    k_cache[dst] = key[token * key_stride + elem];
    v_cache[dst] = value[token * value_stride + elem];
  }
}

} // namespace vllm

// Launches vllm::reshape_and_cache_kernel. The expansion site must have
// `grid`, `block`, `stream`, `key`, `value`, `key_cache`, `value_cache`,
// `slot_mapping`, `key_stride`, `value_stride`, `num_heads`, `head_size`,
// `block_size`, `x` and `kv_scale` in scope.
// KV_T is the data type of the key and value tensors.
// CACHE_T is the stored data type of the kv-cache.
// KV_DTYPE is the Fp8KVCacheDataType tag for the cache contents: kAuto stores
// values unconverted, any other tag converts them with fp8::scaled_convert.
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE)                                     \
  vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE><<<grid, block, 0, stream>>>(      \
    reinterpret_cast<KV_T*>(key.data_ptr()),                                                \
    reinterpret_cast<KV_T*>(value.data_ptr()),                                              \
    reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),                                       \
    reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),                                     \
    slot_mapping.data_ptr<int64_t>(),                                                       \
    key_stride,                                                                             \
    value_stride,                                                                           \
    num_heads,                                                                              \
    head_size,                                                                              \
    block_size,                                                                             \
    x,                                                                                      \
    kv_scale);

// Writes `key` and `value` into the paged caches at the slots listed in
// `slot_mapping`. A negative slot skips that token. `kv_cache_dtype` picks,
// via DISPATCH_BY_KV_CACHE_DTYPE, whether elements are stored as-is or
// converted with `kv_scale`.
//
// Launch: one thread block per token, with min(num_heads * head_size, 512)
// threads, on the current stream of key's device.
void reshape_and_cache(
  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping,  // [num_tokens]
  const std::string& kv_cache_dtype,
  const float kv_scale)
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  if (num_tokens == 0 || num_heads * head_size == 0) {
    // Nothing to store. A zero-sized grid or block would also be an invalid
    // launch configuration whose error is left pending for a later call.
    return;
  }

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, CALL_RESHAPE_AND_CACHE)
}

// Writes `key` and `value` into caches laid out as
// [num_blocks, block_size, num_heads, head_size], at the slots listed in
// `slot_mapping` (a negative slot skips that token). Only
// kv_cache_dtype == "auto" is accepted; the cache dtype must equal key's
// dtype, because the kernel is instantiated with a single scalar_t.
void reshape_and_cache_flash(
  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
  torch::Tensor& k_cache,       // [num_blocks, block_size, num_heads, head_size]
  torch::Tensor& v_cache,       // [num_blocks, block_size, num_heads, head_size]
  torch::Tensor& slot_mapping,  // [num_tokens]
  const std::string& kv_cache_dtype)
{
  // FIXME: only support auto datatype, does not support fp8
  TORCH_CHECK(kv_cache_dtype == "auto", "Unsupported data type of kv cache: ", kv_cache_dtype);

  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = k_cache.size(1);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);
  int block_stride = k_cache.stride(0);
  // The kernel indexes both caches with k_cache's block stride.
  TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0),
              "k_cache and v_cache must have the same block stride");

  if (num_tokens == 0 || num_heads * head_size == 0) {
    // Nothing to store. A zero-sized grid or block would also be an invalid
    // launch configuration.
    return;
  }

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    key.scalar_type(),
    "reshape_and_cache_flash",
    [&] {
      vllm::reshape_and_cache_flash_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key.data_ptr<scalar_t>(),
        value.data_ptr<scalar_t>(),
        k_cache.data_ptr<scalar_t>(),
        v_cache.data_ptr<scalar_t>(),
        slot_mapping.data_ptr<int64_t>(),
        block_stride,
        key_stride,
        value_stride,
        num_heads,
        head_size,
        block_size);
    });
}

namespace vllm {

// Converts src_cache into dst_cache one cache block per thread block.
//
// Launch layout: blockIdx.x selects the block and threads cover its
// block_stride elements in steps of blockDim.x. Each element is converted
// with fp8::scaled_convert<Tout, Tin, kv_dt>(.., kv_scale). Both caches are
// indexed with the same flat offsets, so they must share a layout.
template<typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__global__ void convert_fp8_kernel(
  const Tin* __restrict__ src_cache,
  Tout* __restrict__ dst_cache,
  const float kv_scale,
  const int64_t block_stride) {
  const int64_t block_idx = blockIdx.x;
  // 64-bit loop index to match block_stride: an int index would overflow
  // (undefined behavior) before reaching a stride above INT_MAX.
  for (int64_t i = threadIdx.x; i < block_stride; i += blockDim.x) {
    const int64_t idx = block_idx * block_stride + i;
    dst_cache[idx] = fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale);
  }
}

} // namespace vllm

// Launches vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE>. The expansion site
// must have `grid`, `block`, `stream`, `src_cache`, `dst_cache`, `kv_scale`
// and `block_stride` in scope.
#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE)                                       \
  vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE>                                     \
      <<<grid, block, 0, stream>>>(static_cast<Tin*>(src_cache.data_ptr()),         \
                                   static_cast<Tout*>(dst_cache.data_ptr()),        \
                                   kv_scale, block_stride);

// Only for testing.
//
// Converts `src_cache` into `dst_cache` element by element, one cache block
// (a slice along dim 0) per thread block, using fp8::scaled_convert.
// The direction is chosen from the dtypes:
//   - src Float/Half/BFloat16 -> dst treated as uint8;
//   - otherwise dst Float/Half/BFloat16 <- src treated as uint8.
// kv_cache_dtype "auto" selects kAuto; "fp8" and "fp8_e4m3" select kFp8E4M3.
// NOTE(review): dst is indexed with src's block stride, so both tensors are
// assumed to share the same shape and layout; this is not checked here.
void convert_fp8(
  torch::Tensor& dst_cache,
  torch::Tensor& src_cache,
  const float kv_scale,
  const std::string& kv_cache_dtype)
{
  torch::Device src_device = src_cache.device();
  torch::Device dst_device = dst_cache.device();
  TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU");
  TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU");
  TORCH_CHECK(
    src_device.index() == dst_device.index(),
    "src and dst must be on the same GPU");
  at::cuda::OptionalCUDAGuard device_guard(src_device);

  if (src_cache.numel() == 0) {
    // Nothing to convert. A zero-sized grid would also be an invalid launch.
    return;
  }

  int64_t num_blocks = src_cache.size(0);
  int64_t block_stride = src_cache.stride(0);

  dim3 grid(num_blocks);
  dim3 block(std::min(block_stride, int64_t(512)));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (kv_cache_dtype == "auto") {
    if (src_cache.dtype() == at::ScalarType::Float) {
      CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto);
    } else if (src_cache.dtype() == at::ScalarType::Half) {
      CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);
    } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
      CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);
    } else if (dst_cache.dtype() == at::ScalarType::Float) {
      CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
    } else if (dst_cache.dtype() == at::ScalarType::Half) {
      CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
    } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
      CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
    } else {
      // Previously fell through silently, leaving dst_cache untouched.
      TORCH_CHECK(false, "convert_fp8: unsupported dtype combination: src ",
                  src_cache.scalar_type(), ", dst ", dst_cache.scalar_type());
    }
  } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
    if (src_cache.dtype() == at::ScalarType::Float) {
      CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3);
    } else if (src_cache.dtype() == at::ScalarType::Half) {
      CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
    } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
      CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kFp8E4M3);
    } else if (dst_cache.dtype() == at::ScalarType::Float) {
      CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
    } else if (dst_cache.dtype() == at::ScalarType::Half) {
      CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
    } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
      CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
    } else {
      // Previously fell through silently, leaving dst_cache untouched.
      TORCH_CHECK(false, "convert_fp8: unsupported dtype combination: src ",
                  src_cache.scalar_type(), ", dst ", dst_cache.scalar_type());
    }
  } else {
    TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
  }
}