#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h"
#include "dispatch_utils.h"

#include <algorithm>
#include <cassert>
#include <map>
#include <vector>
// Copies whole cache blocks from `src` to `dst` according to `block_mapping`
// (source block number -> destination block number).
//
// Supported directions: GPU->GPU (both tensors on the same device), GPU->CPU
// and CPU->GPU; any other combination raises. The size of one block in bytes
// is derived from `src` (element size times the number of elements in
// src[0]); `dst` is presumably laid out identically - TODO confirm at callers.
// The copies are enqueued with cudaMemcpyAsync on the current CUDA stream of
// the GPU-side tensor and are not synchronized here.
void swap_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping) {
  torch::Device src_device = src.device();
  torch::Device dst_device = dst.device();
  cudaMemcpyKind memcpy_type;
  if (src_device.is_cuda() && dst_device.is_cuda()) {
    TORCH_CHECK(
      src_device.index() == dst_device.index(),
      "src and dst must be on the same GPU");
    memcpy_type = cudaMemcpyDeviceToDevice;
  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
    memcpy_type = cudaMemcpyDeviceToHost;
  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
    memcpy_type = cudaMemcpyHostToDevice;
  } else {
    TORCH_CHECK(false, "Invalid device combination");
  }

  // Nothing to copy; also avoids indexing src[0] on a cache with no blocks.
  if (block_mapping.empty()) {
    return;
  }

  char* src_ptr = static_cast<char*>(src.data_ptr());
  char* dst_ptr = static_cast<char*>(dst.data_ptr());

  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();

  // The guard (and therefore the current stream) must refer to the GPU side
  // of the transfer. For a host-to-device copy `src_device` is the CPU, so
  // select whichever tensor actually lives on a CUDA device.
  const torch::Device cuda_device = src_device.is_cuda() ? src_device : dst_device;
  const at::cuda::OptionalCUDAGuard device_guard(cuda_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // NOTE(woosuk): This can be slow if the number of blocks is large.
  for (const auto& pair : block_mapping) {
    const int64_t src_offset = pair.first * block_size_in_bytes;
    const int64_t dst_offset = pair.second * block_size_in_bytes;
    const cudaError_t err = cudaMemcpyAsync(
      dst_ptr + dst_offset,
      src_ptr + src_offset,
      block_size_in_bytes,
      memcpy_type,
      stream);
    // Surface enqueue failures here instead of at some unrelated later call.
    TORCH_CHECK(err == cudaSuccess,
                "cudaMemcpyAsync failed in swap_blocks: ", cudaGetErrorString(err));
  }
}
namespace vllm {

// Duplicates cache blocks inside every layer's key and value cache.
//
// Grid: (num_layers, num_pairs). blockIdx.x selects the layer and blockIdx.y
// selects one (src, dst) entry of `block_mapping`, which is a flat array of
// pairs: block_mapping[2 * p] is the source block and block_mapping[2 * p + 1]
// the destination block. The threads of a block stride over the
// `numel_per_block` elements of one cache block, copying the key block first
// and then the value block.
template<typename scalar_t>
__global__ void copy_blocks_kernel(
  int64_t* key_cache_ptrs,
  int64_t* value_cache_ptrs,
  const int64_t* __restrict__ block_mapping,
  const int numel_per_block) {
  const int layer = blockIdx.x;
  const int pair = blockIdx.y;

  // Per-layer base pointers are passed as raw integers.
  scalar_t* const keys = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer]);
  scalar_t* const values = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer]);

  // Element offsets of the first element of the source / destination block.
  const int64_t src_base = block_mapping[2 * pair] * numel_per_block;
  const int64_t dst_base = block_mapping[2 * pair + 1] * numel_per_block;

  for (int elem = threadIdx.x; elem < numel_per_block; elem += blockDim.x) {
    keys[dst_base + elem] = keys[src_base + elem];
  }
  for (int elem = threadIdx.x; elem < numel_per_block; elem += blockDim.x) {
    values[dst_base + elem] = values[src_base + elem];
  }
}

} // namespace vllm
// Copies cache blocks inside each layer's key and value caches.
//
// `block_mapping` maps a source block number to every destination block
// number that should receive a copy of it. Device, dtype and elements per
// block are read from key_caches[0] only; the other caches are presumably
// identical in those respects - TODO confirm with callers.
void copy_blocks(
  std::vector<torch::Tensor>& key_caches,
  std::vector<torch::Tensor>& value_caches,
  const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
  TORCH_CHECK(key_caches.size() == value_caches.size());
  const int num_layers = static_cast<int>(key_caches.size());
  if (num_layers == 0) {
    return;
  }
  torch::Device cache_device = key_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda());

  // Flatten the mapping into a [src0, dst0, src1, dst1, ...] array.
  std::vector<int64_t> block_mapping_vec;
  for (const auto& pair : block_mapping) {
    const int64_t src_block_number = pair.first;
    for (int64_t dst_block_number : pair.second) {
      block_mapping_vec.push_back(src_block_number);
      block_mapping_vec.push_back(dst_block_number);
    }
  }
  const int64_t num_pairs = static_cast<int64_t>(block_mapping_vec.size() / 2);
  const int numel_per_block = key_caches[0][0].numel();

  // A zero-sized grid or block is an invalid launch configuration (and its
  // error would be reported by a later, unrelated CUDA call); there is
  // nothing to copy in these cases anyway.
  if (num_pairs == 0 || numel_per_block == 0) {
    return;
  }
  // Pairs are mapped onto gridDim.y, whose maximum is 65535.
  TORCH_CHECK(num_pairs <= 65535,
              "copy_blocks: too many block pairs (", num_pairs,
              "); gridDim.y is limited to 65535");

  // Collect the raw data pointers of every layer's caches.
  std::vector<int64_t> key_cache_ptrs(num_layers);
  std::vector<int64_t> value_cache_ptrs(num_layers);
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
    value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  }

  // Move the data structures to the GPU.
  // NOTE: This synchronizes the CPU and GPU.
  torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
    key_cache_ptrs.data(), {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
    value_cache_ptrs.data(), {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor block_mapping_tensor = torch::from_blob(
    block_mapping_vec.data(), {2 * num_pairs}, torch::kInt64).to(cache_device);

  // Launch the kernel.
  dim3 grid(num_layers, static_cast<unsigned int>(num_pairs));
  dim3 block(std::min(1024, numel_per_block));
  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
      vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key_cache_ptrs_tensor.data_ptr<int64_t>(),
        value_cache_ptrs_tensor.data_ptr<int64_t>(),
        block_mapping_tensor.data_ptr<int64_t>(),
        numel_per_block);
    }));
}

namespace vllm {

// Scatters each token's key and value vectors into the paged caches.
//
// Grid: (num_tokens); blockIdx.x is the token. The threads of a block stride
// over the token's num_heads * head_size elements. Token rows of `key` /
// `value` start at token * key_stride / token * value_stride. A negative
// slot_mapping entry marks a padding token, which is skipped.
template<typename scalar_t>
__global__ void reshape_and_cache_kernel(
  const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ value,         // [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key_cache,           // [num_blocks, num_heads, head_size/x, block_size, x]
  scalar_t* __restrict__ value_cache,         // [num_blocks, num_heads, head_size, block_size]
  const int64_t* __restrict__ slot_mapping,   // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x) {
  const int64_t token = blockIdx.x;
  const int64_t slot = slot_mapping[token];
  if (slot < 0) {
    // Padding token that should be ignored.
    return;
  }

  const int64_t cache_block = slot / block_size;
  const int64_t pos_in_block = slot % block_size;

  // Number of elements occupied by one head inside one cache block.
  const int64_t key_head_stride = static_cast<int64_t>(head_size / x) * block_size * x;
  const int64_t value_head_stride = static_cast<int64_t>(head_size) * block_size;

  // Offsets of head 0 of this token's slot in each cache.
  const int64_t key_slot_base = cache_block * num_heads * key_head_stride + pos_in_block * x;
  const int64_t value_slot_base = cache_block * num_heads * value_head_stride + pos_in_block;

  const int elems_per_token = num_heads * head_size;
  for (int e = threadIdx.x; e < elems_per_token; e += blockDim.x) {
    const int head = e / head_size;
    const int dim_in_head = e % head_size;

    const int64_t key_dst = key_slot_base
                            + head * key_head_stride
                            + static_cast<int64_t>(dim_in_head / x) * block_size * x
                            + dim_in_head % x;
    const int64_t value_dst = value_slot_base
                              + head * value_head_stride
                              + static_cast<int64_t>(dim_in_head) * block_size;

    key_cache[key_dst] = key[token * key_stride + e];
    value_cache[value_dst] = value[token * value_stride + e];
  }
}

} // namespace vllm
// Writes the key/value vectors of `num_tokens` tokens into the paged key and
// value caches at the slots listed in `slot_mapping`.
//
// Only key.stride(0) / value.stride(0) are passed to the kernel, which reads
// num_heads * head_size consecutive elements per token. The head and
// head_size dimensions are therefore assumed to be contiguous - TODO confirm
// with callers.
void reshape_and_cache(
  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping)  // [num_tokens]
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  // A zero-sized grid or block is an invalid launch configuration; its error
  // would otherwise linger and be reported by a later, unrelated CUDA call.
  const int elems_per_token = num_heads * head_size;
  if (num_tokens == 0 || elems_per_token == 0) {
    return;
  }

  dim3 grid(num_tokens);
  dim3 block(std::min(elems_per_token, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    key.scalar_type(),
    "reshape_and_cache_kernel",
    [&] {
      vllm::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key.data_ptr<scalar_t>(),
        value.data_ptr<scalar_t>(),
        key_cache.data_ptr<scalar_t>(),
        value_cache.data_ptr<scalar_t>(),
        slot_mapping.data_ptr<int64_t>(),
        key_stride,
        value_stride,
        num_heads,
        head_size,
        block_size,
        x);
    });
}

namespace vllm {

// Flat element offsets, inside the key and value caches, of element `i`
// (0 <= i < num_heads * head_size) of the token stored at position
// `block_offset` of cache block `block_idx`. Cache layouts:
//   key_cache:   [num_blocks, num_heads, head_size/x, block_size, x]
//   value_cache: [num_blocks, num_heads, head_size, block_size]
// The arithmetic is 64-bit so that large caches cannot overflow `int`.
static __device__ __forceinline__ void gather_cache_offsets(
  const int64_t block_idx,
  const int64_t block_offset,
  const int i,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x,
  int64_t& key_offset,
  int64_t& value_offset) {
  const int head_idx = i / head_size;
  const int head_offset = i % head_size;
  const int x_idx = head_offset / x;  // index along the [head_size/x] dimension
  const int x_offset = head_offset % x;

  const int64_t head_row = block_idx * num_heads + head_idx;

  const int64_t key_head_stride = static_cast<int64_t>(head_size / x) * block_size * x;
  key_offset = head_row * key_head_stride
               + static_cast<int64_t>(x_idx) * block_size * x
               + block_offset * x
               + x_offset;

  const int64_t value_head_stride = static_cast<int64_t>(head_size) * block_size;
  value_offset = head_row * value_head_stride
                 + static_cast<int64_t>(head_offset) * block_size
                 + block_offset;
}

// Gathers each token's key/value vectors out of the paged caches into the
// dense `key` / `value` outputs. This is the reverse direction of
// reshape_and_cache_kernel.
//
// Grid: (num_tokens); blockIdx.x is the token. The threads of a block stride
// over the token's num_heads * head_size elements. Output rows start at
// token * key_stride / token * value_stride. A negative slot_mapping entry is
// treated as padding (as in reshape_and_cache_kernel): nothing is read, and
// that token's output row is left untouched.
template<typename scalar_t>
__global__ void gather_cached_kv_kernel(
  scalar_t* __restrict__ key,             // [num_tokens, [stride], num_heads, head_size]
  scalar_t* __restrict__ value,           // [num_tokens, [stride], num_heads, head_size]
  const scalar_t* __restrict__ key_cache,   // [num_blocks, num_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ value_cache,   // [num_blocks, num_heads, head_size, block_size]
  const int* __restrict__ slot_mapping,   // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x) {
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  if (slot_idx < 0) {
    // Padding token: its slot does not address the cache.
    return;
  }
  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  const int elems_per_token = num_heads * head_size;
  for (int i = threadIdx.x; i < elems_per_token; i += blockDim.x) {
    int64_t src_key_idx;
    int64_t src_value_idx;
    gather_cache_offsets(block_idx, block_offset, i, num_heads, head_size,
                         block_size, x, src_key_idx, src_value_idx);

    key[token_idx * key_stride + i] = VLLM_LDG(&key_cache[src_key_idx]);
    value[token_idx * value_stride + i] = VLLM_LDG(&value_cache[src_value_idx]);
  }
}

// Same contract and launch layout as gather_cached_kv_kernel. In the main loop
// each thread handles `unroll_factor` elements, spaced `unrolled_dim` apart,
// per iteration, and issues all of their loads before any store. A tail loop
// covers the last dim % unroll_factor elements, so any num_heads * head_size
// is handled.
template <typename scalar_t>
__global__ void gather_cached_kv_kernel_optimized(
    scalar_t *__restrict__ key,             // [num_tokens, [stride], num_heads, head_size]
    scalar_t *__restrict__ value,           // [num_tokens, [stride], num_heads, head_size]
    const scalar_t *__restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
    const scalar_t *__restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
    const int *__restrict__ slot_mapping,   // [num_tokens]
    const int key_stride,
    const int value_stride,
    const int num_heads,
    const int head_size,
    const int block_size,
    const int x)
{
    const int64_t token_idx = blockIdx.x;
    const int64_t slot_idx = slot_mapping[token_idx];
    if (slot_idx < 0) {
        // Padding token: its slot does not address the cache.
        return;
    }
    const int64_t block_idx = slot_idx / block_size;
    const int64_t block_offset = slot_idx % block_size;

    const int dim = num_heads * head_size;
    constexpr int unroll_factor = 4;
    const int unrolled_dim = dim / unroll_factor;

    for (int i = threadIdx.x; i < unrolled_dim; i += blockDim.x)
    {
        int64_t tgt_key_indices[unroll_factor];
        int64_t tgt_value_indices[unroll_factor];
        scalar_t keys_to_store[unroll_factor];
        scalar_t values_to_store[unroll_factor];

        #pragma unroll
        for (int j = 0; j < unroll_factor; ++j)
        {
            const int index = i + j * unrolled_dim;

            int64_t src_key_idx;
            int64_t src_value_idx;
            gather_cache_offsets(block_idx, block_offset, index, num_heads, head_size,
                                 block_size, x, src_key_idx, src_value_idx);

            tgt_key_indices[j] = token_idx * key_stride + index;
            tgt_value_indices[j] = token_idx * value_stride + index;
            keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]);
            values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]);
        }

        #pragma unroll
        for (int j = 0; j < unroll_factor; ++j)
        {
            key[tgt_key_indices[j]] = keys_to_store[j];
            value[tgt_value_indices[j]] = values_to_store[j];
        }
    }

    // Tail: elements [unroll_factor * unrolled_dim, dim) are not reached by
    // the loop above when dim is not a multiple of unroll_factor.
    for (int index = unroll_factor * unrolled_dim + threadIdx.x; index < dim; index += blockDim.x)
    {
        int64_t src_key_idx;
        int64_t src_value_idx;
        gather_cache_offsets(block_idx, block_offset, index, num_heads, head_size,
                             block_size, x, src_key_idx, src_value_idx);

        key[token_idx * key_stride + index] = VLLM_LDG(&key_cache[src_key_idx]);
        value[token_idx * value_stride + index] = VLLM_LDG(&value_cache[src_value_idx]);
    }
}

} // namespace vllm
// Gathers the key/value vectors of `num_tokens` tokens out of the paged
// caches into `key` / `value`, reading each token from the slot listed in
// `slot_mapping` (32-bit ints).
//
// Only key.stride(0) / value.stride(0) are passed to the kernel, which writes
// num_heads * head_size consecutive elements per token. The head and
// head_size dimensions are therefore assumed to be contiguous - TODO confirm
// with callers.
void gather_cached_kv(
  torch::Tensor& key,           // [out] [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [out] [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [in]  [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [in]  [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping)  // [in]  [num_tokens]
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  // A zero-sized grid or block is an invalid launch configuration; its error
  // would otherwise linger and be reported by a later, unrelated CUDA call.
  const int elems_per_token = num_heads * head_size;
  if (num_tokens == 0 || elems_per_token == 0) {
    return;
  }

  dim3 grid(num_tokens);
  dim3 block(std::min(elems_per_token, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    key.scalar_type(),
    "gather_cached_kv_kernel_optimized",
    [&] {
      vllm::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
        key.data_ptr<scalar_t>(),
        value.data_ptr<scalar_t>(),
        key_cache.data_ptr<scalar_t>(),
        value_cache.data_ptr<scalar_t>(),
        slot_mapping.data_ptr<int>(),
        key_stride,
        value_stride,
        num_heads,
        head_size,
        block_size,
        x);
    });
}