// cache_kernels.cu
// CUDA kernels and host wrappers for managing the paged KV cache:
// swapping blocks between devices, duplicating blocks on-GPU, and
// scattering new key/value projections into the cache layout.
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <map>
#include <vector>
// Copies cache blocks between two cache tensors that may live on different
// devices (GPU<->GPU on the same device, GPU->CPU, or CPU->GPU).
//
// src/dst:       cache tensors of shape [num_blocks, ...]; block i is the
//                i-th slice along dim 0.
// block_mapping: source block number -> destination block number.
//
// Copies are enqueued asynchronously on the current CUDA stream; the caller
// must synchronize the stream before reading dst.
void swap_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping) {
  torch::Device src_device = src.device();
  torch::Device dst_device = dst.device();
  cudaMemcpyKind memcpy_type;
  if (src_device.is_cuda() && dst_device.is_cuda()) {
    // Device-to-device copies are only supported within a single GPU.
    // TORCH_CHECK (unlike assert) also fires in NDEBUG/release builds.
    TORCH_CHECK(
      src_device.index() == dst_device.index(),
      "swap_blocks: src and dst must be on the same GPU");
    memcpy_type = cudaMemcpyDeviceToDevice;
  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
    memcpy_type = cudaMemcpyDeviceToHost;
  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
    memcpy_type = cudaMemcpyHostToDevice;
  } else {
    // CPU->CPU is unsupported. The original assert(false) compiles out under
    // NDEBUG, which would leave memcpy_type uninitialized; TORCH_CHECK cannot.
    TORCH_CHECK(false, "swap_blocks: invalid device combination");
  }

  // Use char* for byte-wise offsets: pointer arithmetic on void* is a
  // non-standard GNU extension, not legal C++.
  char *src_ptr = static_cast<char*>(src.data_ptr());
  char *dst_ptr = static_cast<char*>(dst.data_ptr());

  // Size of one cache block in bytes = element size * elements per block.
  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  for (const auto& pair : block_mapping) {
    int64_t src_block_number = pair.first;
    int64_t dst_block_number = pair.second;
    int64_t src_offset = src_block_number * block_size_in_bytes;
    int64_t dst_offset = dst_block_number * block_size_in_bytes;
    cudaMemcpyAsync(
      dst_ptr + dst_offset,
      src_ptr + src_offset,
      block_size_in_bytes,
      memcpy_type,
      stream);
  }
}
// Duplicates cache blocks within GPU memory: each source block is copied
// into one or more destination blocks of the same tensor pair.
//
// src/dst:       GPU cache tensors of shape [num_blocks, ...].
// block_mapping: source block number -> list of destination block numbers.
//
// Copies are enqueued asynchronously on the current CUDA stream; the caller
// must synchronize the stream before reading dst.
void copy_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
  torch::Device src_device = src.device();
  torch::Device dst_device = dst.device();
  // Only GPU->GPU is supported. TORCH_CHECK (unlike assert) also fires in
  // NDEBUG/release builds.
  TORCH_CHECK(
    src_device.is_cuda() && dst_device.is_cuda(),
    "copy_blocks: both tensors must be on GPU");
  cudaMemcpyKind memcpy_type = cudaMemcpyDeviceToDevice;

  // Use char* for byte-wise offsets: pointer arithmetic on void* is a
  // non-standard GNU extension, not legal C++.
  char *src_ptr = static_cast<char*>(src.data_ptr());
  char *dst_ptr = static_cast<char*>(dst.data_ptr());

  // Size of one cache block in bytes = element size * elements per block.
  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  for (const auto& pair : block_mapping) {
    int64_t src_block_number = pair.first;
    // The source offset only depends on the source block; hoist it out of
    // the per-destination loop.
    int64_t src_offset = src_block_number * block_size_in_bytes;
    for (int64_t dst_block_number : pair.second) {
      int64_t dst_offset = dst_block_number * block_size_in_bytes;
      cudaMemcpyAsync(
        dst_ptr + dst_offset,
        src_ptr + src_offset,
        block_size_in_bytes,
        memcpy_type,
        stream);
    }
  }
}

namespace cacheflow {

// Scatters each token's key/value vectors from the contiguous projection
// output into the paged KV cache.
//
// Launch layout (see reshape_and_cache below): one thread block per token
// (blockIdx.x == token index); the block's threads stride over that token's
// num_heads * head_size elements.
//
// The key cache tiles the head_size dimension in groups of `x` elements
// ([num_blocks, num_heads, head_size/x, block_size, x]); the value cache
// keeps head_size contiguous ([num_blocks, num_heads, head_size, block_size]).
template<typename scalar_t>
__global__ void reshape_and_cache_kernel(
  const scalar_t* __restrict__ key,     // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ value,   // [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  scalar_t* __restrict__ value_cache,   // [num_blocks, num_heads, head_size, block_size]
  const int* __restrict__ slot_mapping, // [num_tokens]
  const int key_stride,                 // elements between consecutive tokens in `key`
  const int value_stride,               // elements between consecutive tokens in `value`
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x) {
  const int token_idx = blockIdx.x;
  // The flat slot index encodes both the cache block and the position
  // within that block.
  const int slot_idx = slot_mapping[token_idx];
  const int block_idx = slot_idx / block_size;
  const int block_offset = slot_idx % block_size;

  // Block-stride loop over all of this token's key/value elements.
  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int src_key_idx = token_idx * key_stride + i;
    const int src_value_idx = token_idx * value_stride + i;

    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    // Position within the x-wide tiles of the key cache layout.
    const int x_idx = head_offset / x;
    const int x_offset = head_offset % x;

    // Flatten [block_idx, head_idx, x_idx, block_offset, x_offset] into the
    // 5-D key cache layout.
    const int tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                            + head_idx * (head_size / x) * block_size * x
                            + x_idx * block_size * x
                            + block_offset * x
                            + x_offset;
    // Flatten [block_idx, head_idx, head_offset, block_offset] into the
    // 4-D value cache layout.
    const int tgt_value_idx = block_idx * num_heads * head_size * block_size
                              + head_idx * head_size * block_size
                              + head_offset * block_size
                              + block_offset;
    // __ldg routes the loads through the read-only data cache.
    key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]);
    value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]);
  }
}

} // namespace cacheflow

// Writes the key/value projections of a batch of tokens into the paged KV
// cache at the slots given by slot_mapping.
//
// key/value:    [num_tokens, num_heads, head_size] (dim 0 may be strided)
// key_cache:    [num_blocks, num_heads, head_size/x, block_size, x]
// value_cache:  [num_blocks, num_heads, head_size, block_size]
// slot_mapping: [num_tokens], int32 flat slot index per token
//
// Launches one thread block per token on the current CUDA stream; the write
// is asynchronous with respect to the host.
void reshape_and_cache(
  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping)  // [num_tokens]
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  // A zero-sized grid is an invalid CUDA launch configuration; with no
  // tokens there is nothing to cache.
  if (num_tokens == 0) {
    return;
  }

  // Strides (in elements) let the kernel read tokens from a non-contiguous
  // dim 0 (e.g. a slice of a larger projection buffer).
  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  dim3 grid(num_tokens);
  // One thread per element, capped at 512 threads per block; the kernel's
  // block-stride loop covers any remainder.
  dim3 block(std::min(num_heads * head_size, 512));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    key.scalar_type(),
    "reshape_and_cache_kernel",
    [&] {
      cacheflow::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key.data_ptr<scalar_t>(),
        value.data_ptr<scalar_t>(),
        key_cache.data_ptr<scalar_t>(),
        value_cache.data_ptr<scalar_t>(),
        slot_mapping.data_ptr<int>(),
        key_stride,
        value_stride,
        num_heads,
        head_size,
        block_size,
        x);
    });
}