cache.cpp 5.77 KB
Newer Older
1
2
3
4
5
#include <map>
#include <vector>

#include "cpu_types.hpp"

6
7
8
9
10
11
// Select the dtype-dispatch macro used by the entry points in this file.
// On x86-64 builds the *_WITH_E5M2 variant is chosen; every other target
// uses the plain floating-point dispatcher.
// NOTE(review): presumably the E5M2 variant additionally dispatches an FP8
// (E5M2) cache type -- confirm against the macro definitions.
#if defined(__x86_64__)
  #define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2
#else
  #define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES
#endif

12
13
namespace {
template <typename scalar_t>
14
15
void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
                          std::vector<torch::Tensor> const& value_caches,
16
17
18
                          const torch::Tensor& mapping_pairs,
                          const int element_num_per_block,
                          const int layer_num) {
19
  const size_t pair_num = mapping_pairs.size(0);
20
21
22
23
  const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
#pragma omp parallel for collapse(2)
  for (int layer = 0; layer < layer_num; ++layer) {
    for (size_t pair = 0; pair < pair_num; ++pair) {
24
25
      int64_t source_offset =
          element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
26
      int64_t target_offset =
27
          element_num_per_block * mapping_pairs[pair][1].item<int64_t>();
28
29
30
      scalar_t* key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
      scalar_t* source_ptr = key_cache_ptr + source_offset;
      scalar_t* target_ptr = key_cache_ptr + target_offset;
31
32
      std::memcpy(target_ptr, source_ptr, block_bytes);

33
      scalar_t* value_cache_ptr = value_caches[layer].data_ptr<scalar_t>();
34
35
36
37
38
39
40
41
42
      source_ptr = value_cache_ptr + source_offset;
      target_ptr = value_cache_ptr + target_offset;
      std::memcpy(target_ptr, source_ptr, block_bytes);
    }
  }
}

// Scatters per-token key and value head vectors into the block-structured
// key/value caches.
//
// For each token t with slot_mapping[t] >= 0 and each head h:
//   block_index  = slot / block_size
//   block_offset = slot % block_size
//   key element d (d = g * x + i, 0 <= i < x) is written to
//     key_cache[block_index * (num_heads * head_size * block_size)
//               + h * (block_size * head_size)
//               + (g * x) * block_size + block_offset * x + i]
//   value element d is written to
//     value_cache[block_index * (num_heads * head_size * block_size)
//                 + h * (block_size * head_size)
//                 + d * block_size + block_offset]
// Tokens whose slot is negative are skipped and leave the caches untouched.
//
//   key / value          source buffers; the head vector for (t, h) starts at
//                        t * key_stride (resp. value_stride) + h * head_size
//   key_cache            destination key buffer (layout above)
//   value_cache          destination value buffer (layout above)
//   slot_mapping         one destination slot per token
//   num_tokens           number of entries in slot_mapping
//   key_stride           element stride between consecutive tokens in `key`
//   value_stride         element stride between consecutive tokens in `value`
//   num_heads, head_size head count and per-head vector length
//   block_size           slots per cache block
//   x                    size of the contiguous key sub-vector group
//
// NOTE(review): the key loop copies x elements per step, so it assumes
// head_size is a multiple of x -- confirm callers guarantee this.
//
// All offsets are computed in 64-bit arithmetic so that large token counts or
// strides cannot overflow a 32-bit int before being applied to the pointers.
template <typename scalar_t>
void reshape_and_cache_cpu_impl(
    const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
    scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
    const int64_t* __restrict__ slot_mapping, const int num_tokens,
    const int key_stride, const int value_stride, const int num_heads,
    const int head_size, const int block_size, const int x) {
  // Elements occupied by one head inside a block, and by a whole block.
  const int64_t head_elem_num = static_cast<int64_t>(block_size) * head_size;
  const int64_t block_elem_num = head_elem_num * num_heads;

#pragma omp parallel for collapse(2)
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
      const int64_t slot_idx = slot_mapping[token_idx];
      if (slot_idx >= 0) {
        const int64_t head_offset = static_cast<int64_t>(head_idx) * head_size;
        const int64_t src_key_head_idx =
            static_cast<int64_t>(token_idx) * key_stride + head_offset;
        const int64_t src_value_head_idx =
            static_cast<int64_t>(token_idx) * value_stride + head_offset;
        const scalar_t* src_key_head_ptr = key + src_key_head_idx;
        const scalar_t* src_value_head_ptr = value + src_value_head_idx;

        const int64_t block_index = slot_idx / block_size;
        const int64_t block_offset = slot_idx % block_size;
        // Both caches share the same (block, head) base offset.
        const int64_t target_head_base =
            block_elem_num * block_index + head_idx * head_elem_num;
        scalar_t* target_key_head_ptr = key_cache + target_head_base;
        scalar_t* target_value_head_ptr = value_cache + target_head_base;

        // Keys: groups of x consecutive elements stay contiguous; groups for
        // the same slot are block_size * x elements apart.
        for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) {
          const int64_t target_offset =
              static_cast<int64_t>(src_key_idx) * block_size +
              block_offset * x;
          for (int i = 0; i < x; ++i) {
            target_key_head_ptr[target_offset + i] =
                src_key_head_ptr[src_key_idx + i];
          }
        }

        // Values: element d of the head lands at d * block_size + slot offset.
        for (int src_value_idx = 0; src_value_idx < head_size;
             ++src_value_idx) {
          const int64_t target_offset =
              static_cast<int64_t>(src_value_idx) * block_size + block_offset;
          target_value_head_ptr[target_offset] =
              src_value_head_ptr[src_value_idx];
        }
      }
    }
  }
}
89
};  // namespace
90

91
92
93
94
95
// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
                 std::vector<torch::Tensor> const& value_caches,
96
                 const torch::Tensor& block_mapping) {
97
  unsigned num_layers = key_caches.size();
98
99
100
101
102
103
  TORCH_CHECK(num_layers == value_caches.size());
  if (num_layers == 0) {
    return;
  }

  const int element_num_per_block = key_caches[0][0].numel();
104
105
106
107
108
109
  DISPATCH_MACRO(key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
    CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
    copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
                                   element_num_per_block, num_layers);
    CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
  });
110
111
}

112
113
114
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                       torch::Tensor& key_cache, torch::Tensor& value_cache,
                       torch::Tensor& slot_mapping,
115
116
                       const std::string& kv_cache_dtype,
                       torch::Tensor& k_scale, torch::Tensor& v_scale) {
117
118
119
120
121
122
123
124
125
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

126
127
128
129
130
131
132
133
134
  DISPATCH_MACRO(key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
    CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
    reshape_and_cache_cpu_impl<scalar_t>(
        key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
        key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
        slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride, value_stride,
        num_heads, head_size, block_size, x);
    CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
  });
135
136
}

137
138
// CPU stand-in for swap_blocks: the arguments are ignored and the call is
// rejected through TORCH_CHECK with a condition that is always false.
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                 const torch::Tensor& block_mapping) {
  TORCH_CHECK(false, "swap_blocks is unsupported on CPU.");
}