cache_kernels.cu 17.8 KB
Newer Older
1
#include <torch/all.h>
Woosuk Kwon's avatar
Woosuk Kwon committed
2
#include <ATen/cuda/CUDAContext.h>
3
#include <c10/cuda/CUDAGuard.h>
Woosuk Kwon's avatar
Woosuk Kwon committed
4

5
#include "cuda_compat.h"
6
#include "dispatch_utils.h"
7
8

#ifdef USE_ROCM
9
  #include "quantization/fp8/amd/quant_utils.cuh"
10
#else
11
  #include "quantization/fp8/nvidia/quant_utils.cuh"
12
#endif
13

Woosuk Kwon's avatar
Woosuk Kwon committed
14
#include <algorithm>
Woosuk Kwon's avatar
Woosuk Kwon committed
15
16
#include <cassert>
#include <map>
17
#include <vector>
Woosuk Kwon's avatar
Woosuk Kwon committed
18

19
20
#ifdef USE_ROCM
  #include <hip/hip_bf16.h>
21
typedef __hip_bfloat16 __nv_bfloat16;
22
23
#endif

24
25
// Copies cache blocks from `src` to `dst`.
//
// `block_mapping` is read row by row: element [i][0] is the source block
// number and element [i][1] is the destination block number of the i-th copy.
// Supported device combinations are GPU -> GPU (same device index),
// GPU -> CPU and CPU -> GPU; anything else fails a TORCH_CHECK. Every copy is
// issued with cudaMemcpyAsync on the current CUDA stream of the GPU involved,
// and a failing copy call is reported through C10_CUDA_CHECK instead of being
// silently dropped.
//
// The byte size of one block is taken from `src` (element size times the
// number of elements in src[0]) and is used for offsets into both tensors.
// NOTE(review): this presumes `dst` has the same per-block byte size and that
// both tensors are laid out contiguously per block -- confirm against callers.
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                 const torch::Tensor& block_mapping) {
  torch::Device src_device = src.device();
  torch::Device dst_device = dst.device();
  cudaMemcpyKind memcpy_type;
  if (src_device.is_cuda() && dst_device.is_cuda()) {
    TORCH_CHECK(src_device.index() == dst_device.index(),
                "src and dst must be on the same GPU");
    memcpy_type = cudaMemcpyDeviceToDevice;
  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
    memcpy_type = cudaMemcpyDeviceToHost;
  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
    memcpy_type = cudaMemcpyHostToDevice;
  } else {
    TORCH_CHECK(false, "Invalid device combination");
  }

  // NOTE(youkaichao): keep in mind that `block_mapping` should be
  // a cpu tensor, otherwise every `item` call will require a gpu-cpu
  // synchronization.
  TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");

  const int64_t num_blocks = block_mapping.size(0);
  if (num_blocks == 0) {
    // Nothing to copy. Returning here also avoids indexing src[0] below when
    // there is no work to do.
    return;
  }

  char* src_ptr = static_cast<char*>(src.data_ptr());
  char* dst_ptr = static_cast<char*>(dst.data_ptr());

  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  // Exactly one side may be a CPU tensor; guard on whichever side is the GPU.
  const at::cuda::OptionalCUDAGuard device_guard(
      src_device.is_cuda() ? src_device : dst_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // NOTE(woosuk): This can be slow if the number of blocks is large.
  for (int64_t i = 0; i < num_blocks; i++) {
    const int64_t src_block_number = block_mapping[i][0].item<int64_t>();
    const int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
    const int64_t src_offset = src_block_number * block_size_in_bytes;
    const int64_t dst_offset = dst_block_number * block_size_in_bytes;
    C10_CUDA_CHECK(cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset,
                                   block_size_in_bytes, memcpy_type, stream));
  }
}
Woosuk Kwon's avatar
Woosuk Kwon committed
64

Woosuk Kwon's avatar
Woosuk Kwon committed
65
namespace vllm {
66
67

// Copies one cache block onto another inside a layer's key and value caches.
//
// Grid: (num_layers, num_pairs). blockIdx.x selects the layer and blockIdx.y
// selects the (src, dst) pair; the threads of a block step through the
// elements of the cache block with a stride of blockDim.x.
//
//   key_cache_ptrs / value_cache_ptrs: per-layer cache base addresses stored
//       as 64-bit integers and reinterpreted here as scalar_t*.
//   block_mapping: flattened (src, dst) block numbers; pair p occupies
//       indices 2 * p and 2 * p + 1.
//   numel_per_block: number of scalar_t elements in one cache block.
template <typename scalar_t>
__global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
                                   int64_t* value_cache_ptrs,
                                   const int64_t* __restrict__ block_mapping,
                                   const int numel_per_block) {
  const int layer = blockIdx.x;
  const int pair = blockIdx.y;

  scalar_t* keys = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer]);
  scalar_t* values = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer]);

  // Element offsets of the first element of the source / destination block.
  const int64_t src_base = block_mapping[2 * pair] * numel_per_block;
  const int64_t dst_base = block_mapping[2 * pair + 1] * numel_per_block;

  // Key cache pass.
  for (int elem = threadIdx.x; elem < numel_per_block; elem += blockDim.x) {
    keys[dst_base + elem] = keys[src_base + elem];
  }
  // Value cache pass.
  for (int elem = threadIdx.x; elem < numel_per_block; elem += blockDim.x) {
    values[dst_base + elem] = values[src_base + elem];
  }
}

96
}  // namespace vllm
97

98
99
100
101
102
// Copies cache blocks within every layer's key and value cache on the GPU.
//
// `block_mapping` has shape (num_pairs, 2); each row is a
// (src_block_number, dst_block_number) pair that is applied to all layers.
// Its data pointer is handed to the kernel as-is, so it must be int64 and
// readable from device code.
// NOTE(review): the kernel indexes it as a flat array of 2 * num_pairs
// values, which presumes a contiguous layout -- confirm against callers.
//
// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
//
// Returns without launching anything when there are no layers or no pairs.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
                 std::vector<torch::Tensor> const& value_caches,
                 const torch::Tensor& block_mapping) {
  TORCH_CHECK(key_caches.size() == value_caches.size());
  const int num_layers = static_cast<int>(key_caches.size());
  if (num_layers == 0) {
    return;
  }
  torch::Device cache_device = key_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda());

  // block_mapping is a 2D tensor with shape (num_pairs, 2).
  const int num_pairs = block_mapping.size(0);
  if (num_pairs == 0) {
    // A grid with a zero dimension is an invalid launch configuration, and
    // there is nothing to copy anyway. Returning here also skips the
    // synchronizing host-to-device transfer below.
    return;
  }

  // Create data structures for the kernel.
  // Create an array of pointers to the key and value caches.
  // std::vector is used instead of a variable-length array, which is not
  // standard C++ and would live on the stack.
  std::vector<int64_t> key_cache_ptrs(num_layers);
  std::vector<int64_t> value_cache_ptrs(num_layers);
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    key_cache_ptrs[layer_idx] =
        reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
    value_cache_ptrs[layer_idx] =
        reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  }

  // Move the data structures to the GPU.
  // NOTE: This synchronizes the CPU and GPU.
  torch::Tensor key_cache_ptrs_tensor =
      torch::from_blob(key_cache_ptrs.data(), {num_layers}, torch::kInt64)
          .to(cache_device);
  torch::Tensor value_cache_ptrs_tensor =
      torch::from_blob(value_cache_ptrs.data(), {num_layers}, torch::kInt64)
          .to(cache_device);

  // Launch the kernel.
  const int numel_per_block = key_caches[0][0].numel();
  dim3 grid(num_layers, num_pairs);
  dim3 block(std::min(1024, numel_per_block));
  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
      key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
        vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
            key_cache_ptrs_tensor.data_ptr<int64_t>(),
            value_cache_ptrs_tensor.data_ptr<int64_t>(),
            block_mapping.data_ptr<int64_t>(), numel_per_block);
      }));
  // Surface launch-configuration errors here rather than at a later,
  // unrelated CUDA call.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

Woosuk Kwon's avatar
Woosuk Kwon committed
150
namespace vllm {
Woosuk Kwon's avatar
Woosuk Kwon committed
151

152
// Scatters the key/value vectors of one token into the paged KV cache.
//
// blockIdx.x selects the token; the threads of a block step through the
// token's num_heads * head_size elements with a stride of blockDim.x. Tokens
// whose slot is negative are skipped. When kv_dt is
// Fp8KVCacheDataType::kAuto the values are stored as-is; otherwise they go
// through fp8::scaled_convert with k_scale / v_scale.
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_kernel(
    const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
    const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
    cache_t* __restrict__ key_cache,     // [num_blocks, num_heads, head_size/x,
                                         // block_size, x]
    cache_t* __restrict__ value_cache,   // [num_blocks, num_heads, head_size,
                                         // block_size]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int key_stride, const int value_stride, const int num_heads,
    const int head_size, const int block_size, const int x, const float k_scale,
    const float v_scale) {
  const int64_t token = blockIdx.x;
  const int64_t slot = slot_mapping[token];
  // A negative slot marks a padding token: nothing to store.
  if (slot < 0) {
    return;
  }

  // Split the flat slot into (cache block, position inside that block).
  const int64_t cache_block = slot / block_size;
  const int64_t pos_in_block = slot % block_size;

  const int elems_per_token = num_heads * head_size;
  for (int elem = threadIdx.x; elem < elems_per_token; elem += blockDim.x) {
    // Split the flat element index into (head, dim), and dim into
    // (chunk of x elements, lane inside the chunk) for the key layout.
    const int head = elem / head_size;
    const int dim = elem % head_size;
    const int chunk = dim / x;
    const int lane = dim % x;

    // Key cache layout: [num_blocks, num_heads, head_size/x, block_size, x].
    const int64_t key_dst =
        cache_block * num_heads * (head_size / x) * block_size * x +
        head * (head_size / x) * block_size * x + chunk * block_size * x +
        pos_in_block * x + lane;
    // Value cache layout: [num_blocks, num_heads, head_size, block_size].
    const int64_t value_dst =
        cache_block * num_heads * head_size * block_size +
        head * head_size * block_size + dim * block_size + pos_in_block;

    scalar_t k_elem = key[token * key_stride + elem];
    scalar_t v_elem = value[token * value_stride + elem];
    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
      key_cache[key_dst] = k_elem;
      value_cache[value_dst] = v_elem;
    } else {
      key_cache[key_dst] =
          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(k_elem, k_scale);
      value_cache[value_dst] =
          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(v_elem, v_scale);
    }
  }
}

206
// Scatters the key/value vectors of one token into a KV cache whose key and
// value tensors share the layout [num_blocks, block_size, num_heads,
// head_size].
//
// blockIdx.x selects the token; the threads of a block step through the
// token's num_heads * head_size elements with a stride of blockDim.x. The
// same target index is used for both caches. When kv_dt is
// Fp8KVCacheDataType::kAuto the values are stored as-is; otherwise they go
// through fp8::scaled_convert with k_scale / v_scale.
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_flash_kernel(
    const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
    const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
    cache_t* __restrict__ key_cache,     // [num_blocks, block_size, num_heads,
                                         // head_size]
    cache_t* __restrict__ value_cache,   // [num_blocks, block_size, num_heads,
                                         // head_size]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int block_stride, const int key_stride, const int value_stride,
    const int num_heads, const int head_size, const int block_size,
    const float k_scale, const float v_scale) {
  const int64_t token = blockIdx.x;
  const int64_t slot = slot_mapping[token];
  // NOTE: the slot can be -1 if the token is padded; such tokens are skipped.
  if (slot < 0) {
    return;
  }

  // Split the flat slot into (cache block, position inside that block).
  const int64_t cache_block = slot / block_size;
  const int64_t pos_in_block = slot % block_size;

  const int elems_per_token = num_heads * head_size;
  for (int elem = threadIdx.x; elem < elems_per_token; elem += blockDim.x) {
    const int head = elem / head_size;
    const int dim = elem % head_size;

    // Shared destination index for the key and the value cache.
    const int64_t cache_dst = cache_block * block_stride +
                              pos_in_block * num_heads * head_size +
                              head * head_size + dim;

    scalar_t k_elem = key[token * key_stride + elem];
    scalar_t v_elem = value[token * value_stride + elem];
    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
      key_cache[cache_dst] = k_elem;
      value_cache[cache_dst] = v_elem;
    } else {
      key_cache[cache_dst] =
          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(k_elem, k_scale);
      value_cache[cache_dst] =
          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(v_elem, v_scale);
    }
  }
}
248
}  // namespace vllm
Woosuk Kwon's avatar
Woosuk Kwon committed
249

250
251
252
// Launches vllm::reshape_and_cache_kernel for one type combination.
// KV_T is the data type of the key and value tensors (it is the type the
// `key` / `value` pointers are cast to).
// CACHE_T is the stored data type of the kv-cache (it is the type the
// `key_cache` / `value_cache` pointers are cast to).
// KV_DTYPE is the Fp8KVCacheDataType value forwarded as the kernel's third
// template argument.
// The expansion site must have `grid`, `block`, `stream`, the five tensors
// and the scalar launch arguments named below in scope.
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE)               \
  vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE>             \
      <<<grid, block, 0, stream>>>(                                   \
          reinterpret_cast<KV_T*>(key.data_ptr()),                    \
          reinterpret_cast<KV_T*>(value.data_ptr()),                  \
          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),           \
          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),         \
          slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
          num_heads, head_size, block_size, x, k_scale, v_scale);
262

Woosuk Kwon's avatar
Woosuk Kwon committed
263
// Writes the per-token `key` / `value` vectors into the paged key and value
// caches at the slots given by `slot_mapping`.
//
// One thread block is launched per row of `key` (key.size(0) tokens) with at
// most 512 threads. `kv_cache_dtype` together with key.dtype() selects the
// kernel instantiation through DISPATCH_BY_KV_CACHE_DTYPE; `k_scale` and
// `v_scale` are forwarded to the kernel.
//
// Returns without launching when there are no tokens or no elements per
// token, because a zero-sized grid or block is an invalid launch
// configuration.
void reshape_and_cache(
    torch::Tensor& key,    // [num_tokens, num_heads, head_size]
    torch::Tensor& value,  // [num_tokens, num_heads, head_size]
    torch::Tensor&
        key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor&
        value_cache,  // [num_blocks, num_heads, head_size, block_size]
    torch::Tensor& slot_mapping,  // [num_tokens]
    const std::string& kv_cache_dtype, const double k_scale,
    const double v_scale) {
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  if (num_tokens == 0 || num_heads * head_size == 0) {
    // Nothing to write.
    return;
  }

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
                             CALL_RESHAPE_AND_CACHE);
  // Surface launch-configuration errors here rather than at a later,
  // unrelated CUDA call.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

291
292
293
294
295
296
297
298
299
300
301
302
303
// Launches vllm::reshape_and_cache_flash_kernel for one type combination.
// KV_T is the data type of the key and value tensors (it is the type the
// `key` / `value` pointers are cast to).
// CACHE_T is the stored data type of the kv-cache (it is the type the
// `key_cache` / `value_cache` pointers are cast to).
// KV_DTYPE is the Fp8KVCacheDataType value forwarded as the kernel's third
// template argument.
// The expansion site must have `grid`, `block`, `stream`, the five tensors
// and the scalar launch arguments named below in scope.
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE)         \
  vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE>       \
      <<<grid, block, 0, stream>>>(                                   \
          reinterpret_cast<KV_T*>(key.data_ptr()),                    \
          reinterpret_cast<KV_T*>(value.data_ptr()),                  \
          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),           \
          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),         \
          slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
          value_stride, num_heads, head_size, block_size, k_scale, v_scale);

304
// Writes the per-token `key` / `value` vectors into key and value caches
// that share the layout [num_blocks, block_size, num_heads, head_size], at
// the slots given by `slot_mapping`.
//
// One thread block is launched per entry of `slot_mapping` with at most 512
// threads. `kv_cache_dtype` together with key.dtype() selects the kernel
// instantiation through DISPATCH_BY_KV_CACHE_DTYPE; `k_scale` and `v_scale`
// are forwarded to the kernel. Both caches must have the same stride(0).
//
// Returns without launching when there are no tokens or no elements per
// token, because a zero-sized grid or block is an invalid launch
// configuration.
void reshape_and_cache_flash(
    torch::Tensor& key,        // [num_tokens, num_heads, head_size]
    torch::Tensor& value,      // [num_tokens, num_heads, head_size]
    torch::Tensor& key_cache,  // [num_blocks, block_size, num_heads, head_size]
    torch::Tensor&
        value_cache,  // [num_blocks, block_size, num_heads, head_size]
    torch::Tensor& slot_mapping,  // [num_tokens] or [num_actual_tokens]
    const std::string& kv_cache_dtype, const double k_scale,
    const double v_scale) {
  // NOTE(woosuk): In vLLM V0, key.size(0) is always equal to
  // slot_mapping.size(0) because both include padding.
  // In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0)
  // since key includes padding for CUDA graphs, while slot_mapping does not.
  // In this case, slot_mapping.size(0) represents the actual number of tokens
  // before padding.
  // For compatibility with both cases, we use slot_mapping.size(0) as the
  // number of tokens.
  int num_tokens = slot_mapping.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(1);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);
  int block_stride = key_cache.stride(0);
  // The kernel uses one block stride for both caches.
  TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));

  if (num_tokens == 0 || num_heads * head_size == 0) {
    // Nothing to write.
    return;
  }

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
                             CALL_RESHAPE_AND_CACHE_FLASH);
  // Surface launch-configuration errors here rather than at a later,
  // unrelated CUDA call.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

Woosuk Kwon's avatar
Woosuk Kwon committed
342
namespace vllm {
Woosuk Kwon's avatar
Woosuk Kwon committed
343

344
345
346
// Converts one cache block element by element with fp8::scaled_convert.
//
// blockIdx.x selects the block; the threads of a thread block step through
// the block's `block_stride` elements with a stride of blockDim.x. The same
// flat index is used to read `src_cache` and to write `dst_cache`.
//
// The loop counter is 64-bit so that it cannot overflow when `block_stride`
// exceeds the range of int.
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
                                   Tout* __restrict__ dst_cache,
                                   const float scale,
                                   const int64_t block_stride) {
  const int64_t block_idx = blockIdx.x;
  const int64_t block_base = block_idx * block_stride;
  for (int64_t i = threadIdx.x; i < block_stride; i += blockDim.x) {
    const int64_t idx = block_base + i;
    dst_cache[idx] =
        fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], scale);
  }
}

357
}  // namespace vllm
358

359
360
361
// Launches vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE>, reading from
// `src_cache` (cast to Tin*) and writing to `dst_cache` (cast to Tout*).
// The expansion site must have `grid`, `block`, `stream`, `scale`,
// `block_stride`, `src_cache` and `dst_cache` in scope.
#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE)                   \
  vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE>                 \
      <<<grid, block, 0, stream>>>(                             \
          reinterpret_cast<Tin*>(src_cache.data_ptr()),         \
          reinterpret_cast<Tout*>(dst_cache.data_ptr()), scale, \
          block_stride);
363

364
// Only for testing.
//
// Converts `src_cache` into `dst_cache` with fp8::scaled_convert, one thread
// block per index along dimension 0 of `src_cache` and at most 512 threads
// per block. Both tensors must live on the same GPU.
//
// `kv_cache_dtype` must be "auto", "fp8" or "fp8_e4m3". The direction is
// picked from the tensor dtypes: if `src_cache` is Float / Half / BFloat16
// the destination is written as uint8_t; otherwise, if `dst_cache` is
// Float / Half / BFloat16 the source is read as uint8_t. Any other dtype
// combination is rejected instead of being silently ignored.
//
// NOTE(review): the kernel uses src_cache.stride(0) to index both tensors,
// which presumes `dst_cache` has the same element layout -- confirm against
// callers.
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                 const double scale, const std::string& kv_cache_dtype) {
  torch::Device src_device = src_cache.device();
  torch::Device dst_device = dst_cache.device();
  TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU");
  TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU");
  TORCH_CHECK(src_device.index() == dst_device.index(),
              "src and dst must be on the same GPU");

  // Validate the requested cache dtype up front so that the error is raised
  // even when there is nothing to convert.
  const bool is_auto = kv_cache_dtype == "auto";
  const bool is_fp8_e4m3 =
      kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3";
  TORCH_CHECK(is_auto || is_fp8_e4m3,
              "Unsupported data type: ", kv_cache_dtype);

  at::cuda::OptionalCUDAGuard device_guard(src_device);

  int64_t num_blocks = src_cache.size(0);
  int64_t block_stride = src_cache.stride(0);
  if (num_blocks == 0 || block_stride == 0) {
    // A zero-sized grid or block is an invalid launch configuration, and
    // there is nothing to convert.
    return;
  }

  dim3 grid(num_blocks);
  dim3 block(std::min(block_stride, int64_t(512)));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (is_auto) {
    if (src_cache.dtype() == at::ScalarType::Float) {
      CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto);
    } else if (src_cache.dtype() == at::ScalarType::Half) {
      CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);
    } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
      CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);
    } else if (dst_cache.dtype() == at::ScalarType::Float) {
      CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
    } else if (dst_cache.dtype() == at::ScalarType::Half) {
      CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
    } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
      CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
    } else {
      TORCH_CHECK(false, "Unsupported src/dst dtype combination: ",
                  src_cache.dtype(), " -> ", dst_cache.dtype());
    }
  } else {
    if (src_cache.dtype() == at::ScalarType::Float) {
      CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3);
    } else if (src_cache.dtype() == at::ScalarType::Half) {
      CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
    } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
      CALL_CONVERT_FP8(uint8_t, __nv_bfloat16,
                       vllm::Fp8KVCacheDataType::kFp8E4M3);
    } else if (dst_cache.dtype() == at::ScalarType::Float) {
      CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
    } else if (dst_cache.dtype() == at::ScalarType::Half) {
      CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
    } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
      CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
                       vllm::Fp8KVCacheDataType::kFp8E4M3);
    } else {
      TORCH_CHECK(false, "Unsupported src/dst dtype combination: ",
                  src_cache.dtype(), " -> ", dst_cache.dtype());
    }
  }
  // Surface launch-configuration errors here rather than at a later,
  // unrelated CUDA call.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}