#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"
#ifdef ENABLE_FP8_E5M2
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
#endif

#include <algorithm>
#include <cassert>
#include <map>
#include <vector>

// #ifdef USE_ROCM
//   #include <hip/hip_bf16.h>
//   typedef __hip_bfloat16 __nv_bfloat16;
// #endif

void swap_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping) {
  torch::Device src_device = src.device();
  torch::Device dst_device = dst.device();
  cudaMemcpyKind memcpy_type;
  if (src_device.is_cuda() && dst_device.is_cuda()) {
    TORCH_CHECK(
      src_device.index() == dst_device.index(),
      "src and dst must be on the same GPU");
    memcpy_type = cudaMemcpyDeviceToDevice;
  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
    memcpy_type = cudaMemcpyDeviceToHost;
  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
    memcpy_type = cudaMemcpyHostToDevice;
  } else {
    TORCH_CHECK(false, "Invalid device combination");
  }

  char *src_ptr = static_cast<char*>(src.data_ptr());
  char *dst_ptr = static_cast<char*>(dst.data_ptr());

  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // NOTE(woosuk): This can be slow if the number of blocks is large.
  for (const auto& pair : block_mapping) {
    int64_t src_block_number = pair.first;
    int64_t dst_block_number = pair.second;
    int64_t src_offset = src_block_number * block_size_in_bytes;
    int64_t dst_offset = dst_block_number * block_size_in_bytes;
    cudaMemcpyAsync(
      dst_ptr + dst_offset,
      src_ptr + src_offset,
      block_size_in_bytes,
      memcpy_type,
      stream);
  }
}
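
// Example host-side call (a sketch; `gpu_cache` and `cpu_cache` are
// hypothetical tensors whose first dimension indexes blocks of equal size):
//
//   std::map<int64_t, int64_t> mapping = {{0, 3}, {1, 7}};  // src block -> dst block
//   swap_blocks(gpu_cache, cpu_cache, mapping);             // enqueues D2H copies
//
// The copies are asynchronous on the current CUDA stream, so callers that
// read the destination on the host right away must synchronize that stream.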

namespace vllm {

// Grid: (num_layers, num_pairs)
template<typename scalar_t>
__global__ void copy_blocks_kernel(
  int64_t* key_cache_ptrs,
  int64_t* value_cache_ptrs,
  const int64_t* __restrict__ block_mapping,  // [2 * num_pairs], flattened (src, dst) pairs
  const int numel_per_block) {
  const int layer_idx = blockIdx.x;
  const int pair_idx = blockIdx.y;

  scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
  scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
  int64_t src_block_number = block_mapping[2 * pair_idx];
  int64_t dst_block_number = block_mapping[2 * pair_idx + 1];

  const int64_t src_block_offset = src_block_number * numel_per_block;
  const int64_t dst_block_offset = dst_block_number * numel_per_block;
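  // Threads of this CTA stride over the flattened block contents, copying
  // the key block first and then the value block element by element.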
  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    int64_t src_offset = src_block_offset + i;
    int64_t dst_offset = dst_block_offset + i;
    key_cache[dst_offset] = key_cache[src_offset];
  }
  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    int64_t src_offset = src_block_offset + i;
    int64_t dst_offset = dst_block_offset + i;
    value_cache[dst_offset] = value_cache[src_offset];
  }
}

} // namespace vllm

void copy_blocks(
  std::vector<torch::Tensor>& key_caches,
  std::vector<torch::Tensor>& value_caches,
  const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
  int num_layers = key_caches.size();
  TORCH_CHECK(num_layers == static_cast<int>(value_caches.size()));
  if (num_layers == 0) {
    return;
  }
  torch::Device cache_device = key_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda());

  // Create data structures for the kernel.
  // Create an array of pointers to the key and value caches.
  int64_t key_cache_ptrs[num_layers];
  int64_t value_cache_ptrs[num_layers];
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
    value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  }
  // Create block mapping array.
  std::vector<int64_t> block_mapping_vec;
  for (const auto& pair : block_mapping) {
    int64_t src_block_number = pair.first;
    for (int64_t dst_block_number : pair.second) {
      block_mapping_vec.push_back(src_block_number);
      block_mapping_vec.push_back(dst_block_number);
    }
  }
  int64_t* block_mapping_array = block_mapping_vec.data();
  int num_pairs = block_mapping_vec.size() / 2;

  // Move the data structures to the GPU.
  // NOTE: This synchronizes the CPU and GPU.
  torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
    key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
    value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor block_mapping_tensor = torch::from_blob(
    block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);

  // Launch the kernel.
  const int numel_per_block = key_caches[0][0].numel();
  dim3 grid(num_layers, num_pairs);
  dim3 block(std::min(1024, numel_per_block));
  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
    key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
      vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key_cache_ptrs_tensor.data_ptr<int64_t>(),
        value_cache_ptrs_tensor.data_ptr<int64_t>(),
        block_mapping_tensor.data_ptr<int64_t>(),
        numel_per_block);
    }));
}
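
// Example host-side call (a sketch; `key_caches` and `value_caches` are the
// hypothetical per-layer cache tensors, all on the same GPU):
//
//   std::map<int64_t, std::vector<int64_t>> mapping = {{0, {5, 9}}};  // copy block 0 into 5 and 9
//   copy_blocks(key_caches, value_caches, mapping);
//
// A single kernel launch covers every layer and every (src, dst) pair, which
// is why the cache pointers and the mapping are first staged in GPU tensors.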

namespace vllm {

template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache>
__global__ void reshape_and_cache_kernel(
  const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ value,         // [num_tokens, num_heads, head_size]
  cache_t* __restrict__ key_cache,            // [num_blocks, num_heads, head_size/x, block_size, x]
  cache_t* __restrict__ value_cache,          // [num_blocks, num_heads, head_size, block_size]
  const int64_t* __restrict__ slot_mapping,   // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x) {
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  if (slot_idx < 0) {
    // Padding token that should be ignored.
    return;
  }

  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int64_t src_key_idx = token_idx * key_stride + i;
    const int64_t src_value_idx = token_idx * value_stride + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;
    const int x_offset = head_offset % x;

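    // Key-cache layout: [num_blocks, num_heads, head_size/x, block_size, x].
    // The head dimension is split into head_size/x chunks of x contiguous
    // elements so the attention kernel can issue vectorized loads; the value
    // cache keeps the plain [num_blocks, num_heads, head_size, block_size]
    // layout and is indexed below without the x split.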
    const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                                + head_idx * (head_size / x) * block_size * x
                                + x_idx * block_size * x
                                + block_offset * x
                                + x_offset;
    const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
                                  + head_idx * head_size * block_size
                                  + head_offset * block_size
                                  + block_offset;
    scalar_t tgt_key = key[src_key_idx];
    scalar_t tgt_value = value[src_value_idx];
    if constexpr (is_fp8_e5m2_kv_cache) {
#ifdef ENABLE_FP8_E5M2
      key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
      value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
#else
      assert(false);
#endif
    } else {
      key_cache[tgt_key_idx] = tgt_key;
      value_cache[tgt_value_idx] = tgt_value;
    }
  }
}

} // namespace vllm

#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE)                                \
  vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE><<<grid, block, 0, stream>>>( \
    reinterpret_cast<KV_T*>(key.data_ptr()),                                                       \
    reinterpret_cast<KV_T*>(value.data_ptr()),                                                     \
    reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),                                              \
    reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),                                            \
    slot_mapping.data_ptr<int64_t>(),                                                              \
    key_stride,                                                                                    \
    value_stride,                                                                                  \
    num_heads,                                                                                     \
    head_size,                                                                                     \
    block_size,                                                                                    \
    x);
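
// NOTE: the macro expands at its call sites inside reshape_and_cache() below
// and relies on `grid`, `block`, `stream`, and the kernel arguments being in
// scope there.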

void reshape_and_cache(
  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping,  // [num_tokens]
  const std::string& kv_cache_dtype)
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (kv_cache_dtype == "auto") {
    if (key.dtype() == at::ScalarType::Float) {
      CALL_RESHAPE_AND_CACHE(float, float, false);
    } else if (key.dtype() == at::ScalarType::Half) {
      CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
    // } else if (key.dtype() == at::ScalarType::BFloat16) {
    //   CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
    }
  // } else if (kv_cache_dtype == "fp8_e5m2") {
  //   if (key.dtype() == at::ScalarType::Float) {
  //     CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
  //   } else if (key.dtype() == at::ScalarType::Half) {
  //     CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
  //   } else if (key.dtype() == at::ScalarType::BFloat16) {
  //     CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
  //   }
  } else {
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
  }
}
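
// Example host-side call (a sketch; tensor names are hypothetical):
//
//   // slot_mapping[i] = block_idx * block_size + block_offset for token i,
//   // or a negative value for padding tokens, which the kernel skips.
//   reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, "auto");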

namespace vllm {

template<typename Tout, typename Tin>
__global__ void convert_fp8_e5m2_kernel(
  const Tin* __restrict__ src_cache,
  Tout* __restrict__ dst_cache,
  const int64_t block_stride) {
  const int64_t block_idx = blockIdx.x;
  for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
    int64_t idx = block_idx * block_stride + i;
#ifdef ENABLE_FP8_E5M2
    dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
#else
    assert(false);
#endif
  }
}

} // namespace vllm

#define CALL_CONVERT_FP8_E5M2(Tout, Tin)                                 \
  vllm::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>(  \
    reinterpret_cast<Tin*>(src_cache.data_ptr()),                        \
    reinterpret_cast<Tout*>(dst_cache.data_ptr()),                       \
    block_stride);

void convert_fp8_e5m2(
  torch::Tensor& src_cache,
  torch::Tensor& dst_cache)
{
  int64_t num_blocks = src_cache.size(0);
  int64_t block_stride = src_cache.stride(0);

  dim3 grid(num_blocks);
  dim3 block(std::min(block_stride, int64_t(512)));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (src_cache.dtype() == at::ScalarType::Float) {
    CALL_CONVERT_FP8_E5M2(uint8_t, float);
  } else if (src_cache.dtype() == at::ScalarType::Half) {
    CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
  // } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
  //   CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
  } else if (dst_cache.dtype() == at::ScalarType::Float) {
    CALL_CONVERT_FP8_E5M2(float, uint8_t);
  } else if (dst_cache.dtype() == at::ScalarType::Half) {
    CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
  // } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
  //   CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
  }
}
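
// Example host-side calls (a sketch; tensor names are hypothetical, and the
// fp8 path requires building with ENABLE_FP8_E5M2):
//
//   convert_fp8_e5m2(half_cache, fp8_cache);  // quantize fp16 -> uint8 (e5m2)
//   convert_fp8_e5m2(fp8_cache, half_cache);  // dequantize uint8 -> fp16
//
// The direction is inferred from the dtypes: a float/half source quantizes to
// uint8, otherwise a uint8 source dequantizes to the destination dtype.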