cache_kernels.cu 7.83 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
3
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

Woosuk Kwon's avatar
Woosuk Kwon committed
4
#include <algorithm>
Woosuk Kwon's avatar
Woosuk Kwon committed
5
6
#include <cassert>
#include <map>
7
#include <vector>
Woosuk Kwon's avatar
Woosuk Kwon committed
8

9
// Copies whole cache blocks from `src` to `dst` according to `block_mapping`
// (source block number -> destination block number).
//
// A "block" is one slice along dimension 0 of `src`. Its size in bytes is
// src.element_size() * src[0].numel(), and the same byte size is assumed for
// `dst`.
//
// Supported device pairs:
//   CUDA -> CUDA (same device index), CUDA -> CPU, CPU -> CUDA.
// Any other combination raises an error.
//
// Each block is copied with cudaMemcpyAsync on the current CUDA stream. The
// function therefore returns before the copies have necessarily completed.
//
// NOTE(review): assumes src and dst are contiguous, have the same dtype and
// the same per-block layout; none of this is checked here.
// NOTE(review): host<->device copies are only truly asynchronous when the CPU
// tensor lives in pinned memory; confirm callers allocate it that way.
void swap_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping) {
  torch::Device src_device = src.device();
  torch::Device dst_device = dst.device();
  cudaMemcpyKind memcpy_type;
  if (src_device.is_cuda() && dst_device.is_cuda()) {
    // TORCH_CHECK instead of assert: assert is compiled out under NDEBUG.
    TORCH_CHECK(
      src_device.index() == dst_device.index(),
      "src and dst must be on the same GPU");
    memcpy_type = cudaMemcpyDeviceToDevice;
  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
    memcpy_type = cudaMemcpyDeviceToHost;
  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
    memcpy_type = cudaMemcpyHostToDevice;
  } else {
    // Previously assert(false). In release builds that fell through and left
    // memcpy_type uninitialized.
    TORCH_CHECK(false, "Invalid device combination");
  }

  if (block_mapping.empty()) {
    // Nothing to copy. This also avoids indexing src[0] on an empty tensor.
    return;
  }

  // Byte-addressable pointers: pointer arithmetic on void* is not valid C++.
  char* src_ptr = static_cast<char*>(src.data_ptr());
  char* dst_ptr = static_cast<char*>(dst.data_ptr());

  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  for (const auto& pair : block_mapping) {
    const int64_t src_offset = pair.first * block_size_in_bytes;
    const int64_t dst_offset = pair.second * block_size_in_bytes;
    const cudaError_t err = cudaMemcpyAsync(
      dst_ptr + dst_offset,
      src_ptr + src_offset,
      block_size_in_bytes,
      memcpy_type,
      stream);
    TORCH_CHECK(
      err == cudaSuccess,
      "cudaMemcpyAsync failed in swap_blocks: ", cudaGetErrorString(err));
  }
}
Woosuk Kwon's avatar
Woosuk Kwon committed
45

46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
namespace cacheflow {

// Copies cache blocks within every layer's key cache and value cache.
//
// Launch layout:
//   Grid: (num_layers, num_pairs).
//   - blockIdx.x selects the layer.
//   - blockIdx.y selects one (src, dst) pair from `block_mapping`.
//   The threads of each CUDA block step through the `numel_per_block`
//   elements of that cache block in increments of blockDim.x.
//
// Parameters:
//   key_cache_ptrs / value_cache_ptrs: per-layer base addresses stored as
//     int64_t; each is reinterpreted as scalar_t* here.
//   block_mapping: flattened pairs [src0, dst0, src1, dst1, ...].
//   numel_per_block: number of scalar_t elements in one cache block.
//
// NOTE(review): if a block number appears as a destination in one pair and as
// a source in another, the result depends on CUDA block scheduling; callers
// presumably avoid such mappings — confirm.
template<typename scalar_t>
__global__ void copy_blocks_kernel(
  int64_t* key_cache_ptrs,
  int64_t* value_cache_ptrs,
  const int* __restrict__ block_mapping,
  const int numel_per_block) {
  const int layer_idx = blockIdx.x;
  const int pair_idx = blockIdx.y;

  scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
  scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
  const int src_block_number = block_mapping[2 * pair_idx];
  const int dst_block_number = block_mapping[2 * pair_idx + 1];

  // Compute element offsets in 64-bit. block_number * numel_per_block can
  // exceed INT_MAX for large caches, which overflowed the former int math.
  const int64_t src_block_offset =
    static_cast<int64_t>(src_block_number) * numel_per_block;
  const int64_t dst_block_offset =
    static_cast<int64_t>(dst_block_number) * numel_per_block;

  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    key_cache[dst_block_offset + i] = key_cache[src_block_offset + i];
  }
  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
    value_cache[dst_block_offset + i] = value_cache[src_block_offset + i];
  }
}

} // namespace cacheflow

79
// Copies cache blocks inside every layer's key and value cache on the GPU.
//
// `block_mapping` maps a source block number to the list of destination block
// numbers it should be copied to. The same mapping is applied to every layer
// of both `key_caches` and `value_caches`.
//
// Requirements:
//   - All caches must be CUDA tensors.
//   - All caches are assumed to share the dtype and per-block element count
//     of key_caches[0] (only that tensor is inspected).
//
// The work is enqueued on the current CUDA stream as a single
// copy_blocks_kernel launch with grid (num_layers, num_pairs).
void copy_blocks(
  std::vector<torch::Tensor>& key_caches,
  std::vector<torch::Tensor>& value_caches,
  const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
  TORCH_CHECK(
    key_caches.size() == value_caches.size(),
    "key_caches and value_caches must have the same number of layers");
  const int num_layers = static_cast<int>(key_caches.size());
  if (num_layers == 0) {
    return;
  }
  torch::Device cache_device = key_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda());

  // Flatten the mapping into [src0, dst0, src1, dst1, ...] for the kernel.
  // NOTE(review): block numbers are narrowed from int64_t to int because the
  // kernel reads an int array; confirm block counts stay below INT_MAX.
  std::vector<int> block_mapping_vec;
  for (const auto& pair : block_mapping) {
    const int src_block_number = static_cast<int>(pair.first);
    for (const int64_t dst_block_number : pair.second) {
      block_mapping_vec.push_back(src_block_number);
      block_mapping_vec.push_back(static_cast<int>(dst_block_number));
    }
  }
  const int num_pairs = static_cast<int>(block_mapping_vec.size() / 2);
  if (num_pairs == 0) {
    // Nothing to copy. This also avoids launching with gridDim.y == 0, which
    // is an invalid launch configuration.
    return;
  }

  // Per-layer cache base addresses, handed to the kernel as int64_t.
  // std::vector replaces the former variable-length arrays, which are not
  // standard C++.
  std::vector<int64_t> key_cache_ptrs(num_layers);
  std::vector<int64_t> value_cache_ptrs(num_layers);
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    key_cache_ptrs[layer_idx] =
      reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
    value_cache_ptrs[layer_idx] =
      reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  }

  // Move the data structures to the GPU.
  // NOTE: This synchronizes the CPU and GPU.
  torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
    key_cache_ptrs.data(), {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
    value_cache_ptrs.data(), {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor block_mapping_tensor = torch::from_blob(
    block_mapping_vec.data(), {2 * num_pairs}, torch::kInt).to(cache_device);

  // Launch the kernel.
  // NOTE(review): gridDim.y is limited to 65535 by CUDA. A larger num_pairs
  // makes the launch fail; the error check below surfaces that instead of
  // silently skipping the copy.
  const int numel_per_block = key_caches[0][0].numel();
  dim3 grid(num_layers, num_pairs);
  dim3 block(std::min(1024, numel_per_block));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
      cacheflow::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key_cache_ptrs_tensor.data_ptr<int64_t>(),
        value_cache_ptrs_tensor.data_ptr<int64_t>(),
        block_mapping_tensor.data_ptr<int>(),
        numel_per_block);
    }));
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(
    err == cudaSuccess,
    "copy_blocks_kernel launch failed: ", cudaGetErrorString(err));
}

Woosuk Kwon's avatar
Woosuk Kwon committed
135
136
namespace cacheflow {

// Scatters the key/value vectors of each token into the paged caches.
//
// Launch layout:
//   Grid: (num_tokens). Each CUDA block handles one token; its threads step
//   through that token's num_heads * head_size elements in increments of
//   blockDim.x.
//
// Slot addressing:
//   slot_mapping[token] is a flat slot index, split as
//     block_idx    = slot / block_size
//     block_offset = slot % block_size
//
// Layouts:
//   key, value:  [num_tokens, num_heads, head_size], with per-token strides
//                key_stride / value_stride
//   key_cache:   [num_blocks, num_heads, head_size/x, block_size, x]
//   value_cache: [num_blocks, num_heads, head_size, block_size]
//
// Indices into the inputs and the caches are computed in 64-bit, because
// block_idx * num_heads * head_size * block_size can exceed INT_MAX for large
// caches.
//
// NOTE(review): assumes head_size is a multiple of x and that each token's
// (num_heads, head_size) elements are contiguous in key/value; confirm at
// call sites.
template<typename scalar_t>
__global__ void reshape_and_cache_kernel(
  const scalar_t* __restrict__ key,     // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ value,   // [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  scalar_t* __restrict__ value_cache,   // [num_blocks, num_heads, head_size, block_size]
  const int* __restrict__ slot_mapping, // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x) {
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int64_t src_key_idx = token_idx * key_stride + i;
    const int64_t src_value_idx = token_idx * value_stride + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;
    const int x_offset = head_offset % x;

    // block_idx is int64_t, so each leading product below is evaluated in
    // 64-bit; the remaining terms are bounded by a single cache block.
    const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                              + head_idx * (head_size / x) * block_size * x
                              + x_idx * block_size * x
                              + block_offset * x
                              + x_offset;
    const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
                                + head_idx * head_size * block_size
                                + head_offset * block_size
                                + block_offset;
    key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]);
    value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]);
  }
}

} // namespace cacheflow

Woosuk Kwon's avatar
Woosuk Kwon committed
181
// Writes the keys and values of `num_tokens` tokens into the paged key/value
// caches at the slots given by `slot_mapping`.
//
// Launches reshape_and_cache_kernel on the current CUDA stream with one CUDA
// block per token and min(num_heads * head_size, 512) threads per block.
// slot_mapping is read through data_ptr<int>(), so it must be an int32 tensor.
void reshape_and_cache(
  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping)  // [num_tokens]
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  if (num_tokens == 0) {
    // Nothing to cache. This also avoids launching a zero-sized grid, which
    // is an invalid launch configuration.
    return;
  }
  // The kernel splits head_size into (head_size / x) chunks of x elements.
  // A remainder would produce out-of-range indices into key_cache.
  TORCH_CHECK(
    x > 0 && head_size % x == 0,
    "head_size (", head_size, ") must be a multiple of key_cache.size(4) (", x, ")");

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    key.scalar_type(),
    "reshape_and_cache_kernel",
    [&] {
      cacheflow::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key.data_ptr<scalar_t>(),
        value.data_ptr<scalar_t>(),
        key_cache.data_ptr<scalar_t>(),
        value_cache.data_ptr<scalar_t>(),
        slot_mapping.data_ptr<int>(),
        key_stride,
        value_stride,
        num_heads,
        head_size,
        block_size,
        x);
    });
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(
    err == cudaSuccess,
    "reshape_and_cache_kernel launch failed: ", cudaGetErrorString(err));
}