#include <cstring>  // std::memcpy
#include <map>
#include <vector>

#include "cpu_types.hpp"  // torch headers plus VLLM_DISPATCH_* / CPU_KERNEL_GUARD_* helpers
namespace {
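// Copies whole KV-cache blocks: for each (src, dst) pair and each layer, one
// memcpy moves the key block and another the value block. The layer and pair
// loops are collapsed into a single OpenMP parallel loop.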
template <typename scalar_t>
void copy_blocks_cpu_impl(
    std::vector<torch::Tensor> &key_caches,
    std::vector<torch::Tensor> &value_caches,
    const std::vector<std::pair<int64_t, int64_t>> &mapping_pairs,
    const int element_num_per_block, const int layer_num) {
  const size_t pair_num = mapping_pairs.size();
  const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
#pragma omp parallel for collapse(2)
  for (int layer = 0; layer < layer_num; ++layer) {
    for (size_t pair = 0; pair < pair_num; ++pair) {
      int64_t source_offset = element_num_per_block * mapping_pairs[pair].first;
      int64_t target_offset =
          element_num_per_block * mapping_pairs[pair].second;
      scalar_t *key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
      scalar_t *source_ptr = key_cache_ptr + source_offset;
      scalar_t *target_ptr = key_cache_ptr + target_offset;
      std::memcpy(target_ptr, source_ptr, block_bytes);

      scalar_t *value_cache_ptr = value_caches[layer].data_ptr<scalar_t>();
      source_ptr = value_cache_ptr + source_offset;
      target_ptr = value_cache_ptr + target_offset;
      std::memcpy(target_ptr, source_ptr, block_bytes);
    }
  }
}

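// Writes one token's key/value heads into the paged KV cache. Expected
// layouts (per the copy loops below):
//   key_cache:   [num_blocks, num_heads, head_size/x, block_size, x]
//   value_cache: [num_blocks, num_heads, head_size, block_size]
// where `x` is the innermost tile width of the key cache.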
template <typename scalar_t>
void reshape_and_cache_cpu_impl(
    const scalar_t *__restrict__ key, const scalar_t *__restrict__ value,
    scalar_t *__restrict__ key_cache, scalar_t *__restrict__ value_cache,
    const int64_t *__restrict__ slot_mapping, const int num_tokens,
    const int key_stride, const int value_stride, const int num_heads,
    const int head_size, const int block_size, const int x) {
  const int block_elem_num = num_heads * head_size * block_size;

#pragma omp parallel for collapse(2)
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
      const int64_t slot_idx = slot_mapping[token_idx];
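      // A negative slot index marks a padding token; skip it.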
      if (slot_idx >= 0) {
        int src_key_head_idx = token_idx * key_stride + head_idx * head_size;
        int src_value_head_idx =
            token_idx * value_stride + head_idx * head_size;
        const scalar_t *src_key_head_ptr = key + src_key_head_idx;
        const scalar_t *src_value_head_ptr = value + src_value_head_idx;
        const int64_t block_index = slot_idx / block_size;
        const int64_t block_offset = slot_idx % block_size;
        scalar_t *target_key_head_ptr = key_cache +
                                        block_elem_num * block_index +
                                        head_idx * block_size * head_size;
        scalar_t *target_value_head_ptr = value_cache +
                                          block_elem_num * block_index +
                                          head_idx * block_size * head_size;

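        // Keys: copy `x` consecutive elements at a time so the innermost
        // cache dimension (size x) stays contiguous.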
        for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) {
          const int64_t target_offset =
              src_key_idx * block_size + block_offset * x;
          for (int i = 0; i < x; ++i) {
            target_key_head_ptr[target_offset + i] =
                src_key_head_ptr[src_key_idx + i];
          }
        }

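        // Values: consecutive elements of one head are strided by block_size
        // in the cache, so copy them one by one.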
        for (int src_value_idx = 0; src_value_idx < head_size;
             ++src_value_idx) {
          const int64_t target_offset =
              src_value_idx * block_size + block_offset;
          target_value_head_ptr[target_offset] =
              src_value_head_ptr[src_value_idx];
        }
      }
    }
  }
}
}  // namespace

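// Copies KV-cache blocks across all layers. `block_mapping` maps each source
// block number to the destination block numbers it should be copied to, e.g.
// {{0, {2, 3}}} copies block 0 into blocks 2 and 3 of every layer.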
void copy_blocks(std::vector<torch::Tensor> &key_caches,
                 std::vector<torch::Tensor> &value_caches,
                 const std::map<int64_t, std::vector<int64_t>> &block_mapping) {
  const int num_layers = static_cast<int>(key_caches.size());
  TORCH_CHECK(num_layers == static_cast<int>(value_caches.size()));
  if (num_layers == 0) {
    return;
  }

  std::vector<std::pair<int64_t, int64_t>> mapping_pairs;
  mapping_pairs.reserve(block_mapping.size());
  for (const auto &pair : block_mapping) {
    for (const auto &dst : pair.second) {
      mapping_pairs.emplace_back(pair.first, dst);
    }
  }

  const int element_num_per_block = key_caches[0][0].numel();
  VLLM_DISPATCH_FLOATING_TYPES(
      key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
        CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
        copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, mapping_pairs,
                                       element_num_per_block, num_layers);
        CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
      });
}

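// Scatters each token's key/value vectors into the paged KV cache at the slot
// given by `slot_mapping`. `kv_cache_dtype` is accepted but ignored on CPU,
// and `kv_scale` must be 1.0f (checked below).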
void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
                       torch::Tensor &key_cache, torch::Tensor &value_cache,
                       torch::Tensor &slot_mapping,
                       const std::string &kv_cache_dtype, float kv_scale) {
  TORCH_CHECK(kv_scale == 1.0f,
              "reshape_and_cache on CPU requires kv_scale == 1.0f");

  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  VLLM_DISPATCH_FLOATING_TYPES(
      key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
        CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
        reshape_and_cache_cpu_impl<scalar_t>(
            key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
            key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
            slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride,
            value_stride, num_heads, head_size, block_size, x);
        CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
      });
}
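
// Usage sketch (illustrative names, not part of this file):
//
//   // key/value:    [num_tokens, num_heads, head_size]
//   // key_cache:    [num_blocks, num_heads, head_size/x, block_size, x]
//   // value_cache:  [num_blocks, num_heads, head_size, block_size]
//   // slot_mapping: [num_tokens], int64
//   reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
//                     "auto", /*kv_scale=*/1.0f);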

// Not implemented for the CPU backend.
void swap_blocks(torch::Tensor &src, torch::Tensor &dst,
                 const std::map<int64_t, int64_t> &block_mapping) {
  TORCH_CHECK(false, "swap_blocks is unsupported on CPU.");
}