cache_kernels.cu 15.9 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
3
#include <c10/cuda/CUDAGuard.h>
Woosuk Kwon's avatar
Woosuk Kwon committed
4

5
#include "cuda_compat.h"
6
#include "dispatch_utils.h"
7
#if defined(ENABLE_FP8_E5M2)
8
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
9
10
#elif defined(ENABLE_FP8_E4M3)
#include "quantization/fp8/amd_detail/quant_utils.cuh"
11
#endif
12

Woosuk Kwon's avatar
Woosuk Kwon committed
13
#include <algorithm>
Woosuk Kwon's avatar
Woosuk Kwon committed
14
15
#include <cassert>
#include <map>
16
#include <vector>
Woosuk Kwon's avatar
Woosuk Kwon committed
17

18
19
20
21
22
#ifdef USE_ROCM
  #include <hip/hip_bf16.h>
  typedef __hip_bfloat16 __nv_bfloat16;
#endif

23
// Copies whole cache blocks from `src` to `dst`.
//
// `block_mapping` maps a source block number to a destination block number.
// The size of one block in bytes is derived from `src` only (element size
// times the number of elements in src[0]); block b starts at byte offset
// b * block_size_in_bytes in both tensors.
// NOTE(review): block numbers are not range-checked and `dst` is presumed to
// use the same per-block layout as `src` -- confirm against the callers.
//
// Supported device combinations are GPU -> GPU (same device index only),
// GPU -> CPU and CPU -> GPU; any other combination raises via TORCH_CHECK.
// Each block is copied with its own cudaMemcpyAsync on the current CUDA
// stream, and an error reported by that call raises via TORCH_CHECK.
void swap_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping) {
  torch::Device src_device = src.device();
  torch::Device dst_device = dst.device();
  cudaMemcpyKind memcpy_type;
  if (src_device.is_cuda() && dst_device.is_cuda()) {
    TORCH_CHECK(
      src_device.index() == dst_device.index(),
      "src and dst must be on the same GPU");
    memcpy_type = cudaMemcpyDeviceToDevice;
  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
    memcpy_type = cudaMemcpyDeviceToHost;
  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
    memcpy_type = cudaMemcpyHostToDevice;
  } else {
    TORCH_CHECK(false, "Invalid device combination");
  }

  // Nothing to copy; skip the tensor indexing and stream lookup below.
  if (block_mapping.empty()) {
    return;
  }

  char *src_ptr = static_cast<char*>(src.data_ptr());
  char *dst_ptr = static_cast<char*>(dst.data_ptr());

  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  // Make the GPU taking part in the transfer the current device.
  const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // NOTE(woosuk): This can be slow if the number of blocks is large.
  for (const auto& pair : block_mapping) {
    const int64_t src_block_number = pair.first;
    const int64_t dst_block_number = pair.second;
    const int64_t src_offset = src_block_number * block_size_in_bytes;
    const int64_t dst_offset = dst_block_number * block_size_in_bytes;
    // Surface a failing copy immediately instead of dropping the error code.
    const cudaError_t err = cudaMemcpyAsync(
      dst_ptr + dst_offset,
      src_ptr + src_offset,
      block_size_in_bytes,
      memcpy_type,
      stream);
    TORCH_CHECK(
      err == cudaSuccess,
      "swap_blocks: cudaMemcpyAsync failed: ", cudaGetErrorString(err));
  }
}
Woosuk Kwon's avatar
Woosuk Kwon committed
63

Woosuk Kwon's avatar
Woosuk Kwon committed
64
namespace vllm {
65
66
67
68
69
70

// Copies one cache block onto another inside the key cache and the value
// cache of a single layer.
// Grid: (num_layers, num_pairs); blockIdx.x selects the layer and blockIdx.y
// selects the (src, dst) pair read from block_mapping. The threads of a block
// stride over the numel_per_block elements of the cache block.
template<typename scalar_t>
__global__ void copy_blocks_kernel(
  int64_t* key_cache_ptrs,
  int64_t* value_cache_ptrs,
  const int64_t* __restrict__ block_mapping,
  const int numel_per_block) {
  const int layer = blockIdx.x;
  const int pair = blockIdx.y;

  // The per-layer cache base addresses are passed as 64-bit integers.
  scalar_t* keys = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer]);
  scalar_t* values = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer]);

  // Each pair occupies two consecutive entries: source, then destination.
  const int64_t* mapping_entry = block_mapping + 2 * pair;
  const int64_t src_base = mapping_entry[0] * numel_per_block;
  const int64_t dst_base = mapping_entry[1] * numel_per_block;

  for (int elem = threadIdx.x; elem < numel_per_block; elem += blockDim.x) {
    keys[dst_base + elem] = keys[src_base + elem];
  }
  for (int elem = threadIdx.x; elem < numel_per_block; elem += blockDim.x) {
    values[dst_base + elem] = values[src_base + elem];
  }
}

Woosuk Kwon's avatar
Woosuk Kwon committed
95
} // namespace vllm
96

97
// Copies cache blocks within every layer's key and value cache on the GPU.
//
// key_caches / value_caches: one tensor per layer; both vectors must have the
//   same length. All caches are dispatched on the dtype of key_caches[0] and
//   use the per-block element count of key_caches[0][0].
// block_mapping: 2D tensor with shape (num_pairs, 2) holding
//   (src block, dst block) pairs; it is read inside the kernel as int64.
//   NOTE(review): presumably it must live on the same GPU as the caches --
//   confirm against the callers.
//
// Returns without launching anything when there are no layers, no pairs, or
// the blocks hold no elements (a zero-sized grid or block is not a valid
// launch configuration).
void copy_blocks(
  std::vector<torch::Tensor>& key_caches,
  std::vector<torch::Tensor>& value_caches,
  torch::Tensor& block_mapping) {
  TORCH_CHECK(
    key_caches.size() == value_caches.size(),
    "key_caches and value_caches must have the same number of layers");
  const int num_layers = static_cast<int>(key_caches.size());
  if (num_layers == 0) {
    return;
  }
  torch::Device cache_device = key_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda());

  // block_mapping is a 2D tensor with shape (num_pairs, 2).
  const int num_pairs = block_mapping.size(0);
  const int numel_per_block = key_caches[0][0].numel();
  if (num_pairs == 0 || numel_per_block == 0) {
    return;
  }

  // Create data structures for the kernel.
  // Collect the raw addresses of the per-layer key and value caches. A
  // std::vector replaces the former variable-length stack arrays, which are
  // not standard C++.
  std::vector<int64_t> key_cache_ptrs(num_layers);
  std::vector<int64_t> value_cache_ptrs(num_layers);
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
    value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  }

  // Move the data structures to the GPU.
  // NOTE: This synchronizes the CPU and GPU.
  torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
    key_cache_ptrs.data(), {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
    value_cache_ptrs.data(), {num_layers}, torch::kInt64).to(cache_device);

  // Launch the kernel.
  dim3 grid(num_layers, num_pairs);
  dim3 block(std::min(1024, numel_per_block));
  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
    key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
      vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key_cache_ptrs_tensor.data_ptr<int64_t>(),
        value_cache_ptrs_tensor.data_ptr<int64_t>(),
        block_mapping.data_ptr<int64_t>(),
        numel_per_block);
    }));
}

Woosuk Kwon's avatar
Woosuk Kwon committed
144
namespace vllm {
Woosuk Kwon's avatar
Woosuk Kwon committed
145

146
// Scatters the key/value vectors of one token (blockIdx.x) into the paged
// key and value caches at the slot given by slot_mapping. Tokens with a
// negative slot are skipped. When is_fp8_kv_cache is set, every element is
// converted with the fp8 helper selected at compile time before it is stored.
template<typename scalar_t, typename cache_t, bool is_fp8_kv_cache>
__global__ void reshape_and_cache_kernel(
  const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ value,         // [num_tokens, num_heads, head_size]
  cache_t* __restrict__ key_cache,            // [num_blocks, num_heads, head_size/x, block_size, x]
  cache_t* __restrict__ value_cache,          // [num_blocks, num_heads, head_size, block_size]
  const int64_t* __restrict__ slot_mapping,   // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x,
  const float kv_scale) {
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  // A negative slot marks a padding token: nothing is written for it.
  if (slot_idx < 0) {
    return;
  }

  // Split the flat slot into (cache block, position inside the block).
  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  // Index parts that are the same for every element of this token.
  const int64_t src_key_base = token_idx * key_stride;
  const int64_t src_value_base = token_idx * value_stride;
  const int64_t key_block_base = block_idx * num_heads * (head_size / x) * block_size * x;
  const int64_t value_block_base = block_idx * num_heads * head_size * block_size;

  const int n = num_heads * head_size;
  for (int elem = threadIdx.x; elem < n; elem += blockDim.x) {
    // Decompose the flat element index into head / in-head coordinates,
    // and the in-head coordinate further into (x_idx, x_offset).
    const int head_idx = elem / head_size;
    const int head_offset = elem % head_size;
    const int x_idx = head_offset / x;
    const int x_offset = head_offset % x;

    const int64_t tgt_key_idx = key_block_base
                                + head_idx * (head_size / x) * block_size * x
                                + x_idx * block_size * x
                                + block_offset * x
                                + x_offset;
    const int64_t tgt_value_idx = value_block_base
                                  + head_idx * head_size * block_size
                                  + head_offset * block_size
                                  + block_offset;

    scalar_t key_elem = key[src_key_base + elem];
    scalar_t value_elem = value[src_value_base + elem];
    if constexpr (is_fp8_kv_cache) {
#if defined(ENABLE_FP8_E5M2)
      key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(key_elem);
      value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(value_elem);
#elif defined(ENABLE_FP8_E4M3)
      key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(key_elem, kv_scale);
      value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(value_elem, kv_scale);
#else
      // Built without any fp8 support: this path must not be reached.
      assert(false);
#endif
    } else {
      key_cache[tgt_key_idx] = key_elem;
      value_cache[tgt_value_idx] = value_elem;
    }
  }
}

208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
// Writes the key/value vectors of one token (blockIdx.x) into k_cache and
// v_cache at the slot given by slot_mapping. Both caches are addressed with
// the same target index. Tokens with a negative slot are skipped.
template<typename scalar_t>
__global__ void reshape_and_cache_flash_kernel(
  const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ value,         // [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ k_cache,             // [num_blocks, block_size, num_heads, head_size]
  scalar_t* __restrict__ v_cache,             // [num_blocks, block_size, num_heads, head_size]
  const int64_t* __restrict__ slot_mapping,   // [num_tokens]
  const int block_stride,
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size) {
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  // NOTE: slot_idx can be -1 if the token is padded
  if (slot_idx < 0) {
    return;
  }

  // Split the flat slot into (cache block, position inside the block).
  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  // Index parts that are the same for every element of this token.
  const int64_t src_key_base = token_idx * key_stride;
  const int64_t src_value_base = token_idx * value_stride;
  const int64_t tgt_block_base = block_idx * block_stride;
  const int64_t tgt_slot_base = block_offset * num_heads * head_size;

  const int n = num_heads * head_size;
  for (int elem = threadIdx.x; elem < n; elem += blockDim.x) {
    const int head_idx = elem / head_size;
    const int head_offset = elem % head_size;
    const int64_t tgt_idx = tgt_block_base
                            + tgt_slot_base
                            + head_idx * head_size
                            + head_offset;
    k_cache[tgt_idx] = key[src_key_base + elem];
    v_cache[tgt_idx] = value[src_value_base + elem];
  }
}
Woosuk Kwon's avatar
Woosuk Kwon committed
243
} // namespace vllm
Woosuk Kwon's avatar
Woosuk Kwon committed
244

245
246
// Launches vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_KV_CACHE> with
// the launch configuration `grid` / `block` on `stream`.
// The expansion refers to these names, which must exist at the expansion
// site: key, value, key_cache, value_cache, slot_mapping, key_stride,
// value_stride, num_heads, head_size, block_size, x, kv_scale, grid, block,
// stream. The tensor data pointers are reinterpreted as KV_T* / CACHE_T*.
// The expansion already ends with a semicolon.
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE)                                     \
  vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_KV_CACHE><<<grid, block, 0, stream>>>(      \
247
248
249
250
251
252
253
254
255
256
    reinterpret_cast<KV_T*>(key.data_ptr()),                                                       \
    reinterpret_cast<KV_T*>(value.data_ptr()),                                                     \
    reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),                                              \
    reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),                                            \
    slot_mapping.data_ptr<int64_t>(),                                                              \
    key_stride,                                                                                    \
    value_stride,                                                                                  \
    num_heads,                                                                                     \
    head_size,                                                                                     \
    block_size,                                                                                    \
257
258
    x,                                                                                             \
    kv_scale);
259

Woosuk Kwon's avatar
Woosuk Kwon committed
260
261
262
263
264
// Writes the per-token key/value vectors into the paged key and value caches.
//
// kv_cache_dtype selects the cache element type:
//   "auto" - the cache stores the same type as `key`
//   "fp8"  - the cache stores 8-bit values (uint8_t storage); kv_scale is
//            forwarded to the kernel
// Any other string raises via TORCH_CHECK. `key` must be Float, Half or
// BFloat16; other input types now raise instead of silently doing nothing.
// With zero tokens, or zero elements per token, nothing is launched (a
// zero-sized grid or block is not a valid launch configuration); in that
// case the dtype arguments are not validated.
void reshape_and_cache(
  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping,  // [num_tokens]
  const std::string& kv_cache_dtype,
  const float kv_scale)
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  if (num_tokens == 0 || num_heads * head_size == 0) {
    return;
  }

  // The local names below are used by the CALL_RESHAPE_AND_CACHE expansion.
  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (kv_cache_dtype == "auto") {
    if (key.dtype() == at::ScalarType::Float) {
      CALL_RESHAPE_AND_CACHE(float, float, false);
    } else if (key.dtype() == at::ScalarType::Half) {
      CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
    } else if (key.dtype() == at::ScalarType::BFloat16) {
      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
    } else {
      TORCH_CHECK(false, "Unsupported input type of kv cache: ", key.dtype());
    }
  } else if (kv_cache_dtype == "fp8") {
    if (key.dtype() == at::ScalarType::Float) {
      CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
    } else if (key.dtype() == at::ScalarType::Half) {
      CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
    } else if (key.dtype() == at::ScalarType::BFloat16) {
      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
    } else {
      TORCH_CHECK(false, "Unsupported input type of kv cache: ", key.dtype());
    }
  } else {
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
  }
}

303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
// Writes the per-token key/value vectors into k_cache / v_cache.
//
// Only kv_cache_dtype == "auto" is accepted; any other value raises via
// TORCH_CHECK. k_cache and v_cache must have the same stride along dim 0.
// With zero tokens, or zero elements per token, nothing is launched (a
// zero-sized grid or block is not a valid launch configuration).
void reshape_and_cache_flash(
  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
  torch::Tensor& k_cache,       // [num_blocks, block_size, num_heads, head_size]
  torch::Tensor& v_cache,       // [num_blocks, block_size, num_heads, head_size]
  torch::Tensor& slot_mapping,  // [num_tokens]
  const std::string& kv_cache_dtype)
{
  // FIXME: only support auto datatype, does not support fp8
  TORCH_CHECK(
    kv_cache_dtype == "auto",
    "Unsupported data type of kv cache: ", kv_cache_dtype);

  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = k_cache.size(1);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);
  int block_stride = k_cache.stride(0);
  // The kernel addresses both caches with one shared target index.
  TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0));

  if (num_tokens == 0 || num_heads * head_size == 0) {
    return;
  }

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    key.scalar_type(),
    "reshape_and_cache_flash",
    [&] {
      vllm::reshape_and_cache_flash_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key.data_ptr<scalar_t>(),
        value.data_ptr<scalar_t>(),
        k_cache.data_ptr<scalar_t>(),
        v_cache.data_ptr<scalar_t>(),
        slot_mapping.data_ptr<int64_t>(),
        block_stride,
        key_stride,
        value_stride,
        num_heads,
        head_size,
        block_size);
    });
}

Woosuk Kwon's avatar
Woosuk Kwon committed
348
namespace vllm {
Woosuk Kwon's avatar
Woosuk Kwon committed
349

350
// Converts one cache block (blockIdx.x) element by element from Tin to Tout
// using the fp8 helper selected at compile time. The threads of a block
// stride over the block_stride elements of the cache block; source and
// destination use the same element index.
template<typename Tout, typename Tin>
__global__ void convert_fp8_kernel(
  const Tin* __restrict__ src_cache,
  Tout* __restrict__ dst_cache,
  const int64_t block_stride) {
  const int64_t block_idx = blockIdx.x;
  // 64-bit loop index: block_stride is 64-bit, so a 32-bit counter could
  // overflow before reaching it.
  for (int64_t i = threadIdx.x; i < block_stride; i += blockDim.x) {
    int64_t idx = block_idx * block_stride + i;
#if defined(ENABLE_FP8_E5M2)
    dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
#elif defined(ENABLE_FP8_E4M3)
    dst_cache[idx] = fp8_e4m3::vec_conversion<Tout, Tin>(src_cache[idx]);
#else
    // Built without any fp8 support: this kernel must not be reached.
    assert(false);
#endif
  }
}

} // namespace vllm

370
371
372
373
// Launches vllm::convert_fp8_kernel<Tout, Tin> with the launch configuration
// `grid` / `block` on `stream`.
// The expansion refers to these names, which must exist at the expansion
// site: src_cache, dst_cache, block_stride, grid, block, stream. The tensor
// data pointers are reinterpreted as Tin* / Tout*.
// The expansion already ends with a semicolon.
#define CALL_CONVERT_FP8(Tout, Tin)                                 \
  vllm::convert_fp8_kernel<Tout, Tin><<<grid, block, 0, stream>>>(  \
    reinterpret_cast<Tin*>(src_cache.data_ptr()),                   \
    reinterpret_cast<Tout*>(dst_cache.data_ptr()),                  \
374
375
    block_stride);

376
// Converts a cache between a floating-point type and 8-bit storage.
//
// Both tensors must be on the same GPU. The direction is chosen from the
// dtypes: if src_cache is Float/Half/BFloat16 it is converted into uint8_t
// storage; otherwise, if dst_cache is Float/Half/BFloat16, src_cache is read
// as uint8_t and converted into that type. Any other combination raises via
// TORCH_CHECK instead of silently doing nothing.
// The number of blocks and the per-block element count are taken from
// src_cache (size(0) and stride(0)).
// NOTE(review): the dtype of the 8-bit side is not verified and dst_cache is
// presumed to share src_cache's layout -- confirm against the callers.
// With no blocks, or a zero block stride, nothing is launched (a zero-sized
// grid or block is not a valid launch configuration); in that case the
// dtypes are not validated.
void convert_fp8(
  torch::Tensor& src_cache,
  torch::Tensor& dst_cache)
{
  torch::Device src_device = src_cache.device();
  torch::Device dst_device = dst_cache.device();
  TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU");
  TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU");
  TORCH_CHECK(
    src_device.index() == dst_device.index(),
    "src and dst must be on the same GPU");
  at::cuda::OptionalCUDAGuard device_guard(src_device);

  int64_t num_blocks = src_cache.size(0);
  int64_t block_stride = src_cache.stride(0);
  if (num_blocks == 0 || block_stride == 0) {
    return;
  }

  // The local names below are used by the CALL_CONVERT_FP8 expansion.
  dim3 grid(num_blocks);
  dim3 block(std::min(block_stride, int64_t(512)));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (src_cache.dtype() == at::ScalarType::Float) {
    CALL_CONVERT_FP8(uint8_t, float);
  } else if (src_cache.dtype() == at::ScalarType::Half) {
    CALL_CONVERT_FP8(uint8_t, uint16_t);
  } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
    CALL_CONVERT_FP8(uint8_t, __nv_bfloat16);
  } else if (dst_cache.dtype() == at::ScalarType::Float) {
    CALL_CONVERT_FP8(float, uint8_t);
  } else if (dst_cache.dtype() == at::ScalarType::Half) {
    CALL_CONVERT_FP8(uint16_t, uint8_t);
  } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
    CALL_CONVERT_FP8(__nv_bfloat16, uint8_t);
  } else {
    TORCH_CHECK(
      false,
      "Unsupported data types for fp8 conversion: src=", src_cache.dtype(),
      ", dst=", dst_cache.dtype());
  }
}