"vscode:/vscode.git/clone" did not exist on "82cabf53a32be91ec08f214e97de06b99d0eef18"
cache.cpp 5.37 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
#include <map>
#include <vector>

#include "cpu_types.hpp"

namespace {
/// Copies cache blocks in place within every layer's key and value caches.
///
/// For each layer in [0, layer_num) and each row `pair` of `mapping_pairs`,
/// the block whose index is mapping_pairs[pair][0] is memcpy'd over the block
/// whose index is mapping_pairs[pair][1]. This is done in both
/// key_caches[layer] and value_caches[layer]. Block `b` is taken to start at
/// element `b * element_num_per_block` of the layer's tensor.
///
/// @param key_caches            per-layer key cache tensors (modified in place).
/// @param value_caches          per-layer value cache tensors (modified in place).
/// @param mapping_pairs         2-D tensor whose rows are (source, target) block
///                              indices.
/// @param element_num_per_block number of scalar_t elements in one block.
/// @param layer_num             number of layers to process.
///
/// NOTE(review): assumes every cache tensor is contiguous with blocks stored
/// back-to-back. Also assumes no block is both a source and a target (and no
/// target repeats), because all (layer, pair) copies run concurrently.
/// Confirm with callers.
template <typename scalar_t>
void copy_blocks_cpu_impl(
    std::vector<torch::Tensor> &key_caches,
    std::vector<torch::Tensor> &value_caches,
    const torch::Tensor& mapping_pairs,
    const int element_num_per_block, const int layer_num) {
  const size_t pair_num = mapping_pairs.size(0);
  const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;

  // Materialize the block indices once, before the parallel region. Indexing
  // the tensor and calling item() per (layer, pair) would allocate temporary
  // tensors and go through the ATen dispatcher from every OpenMP thread.
  const torch::Tensor mapping = mapping_pairs.to(torch::kCPU, torch::kLong);
  const auto mapping_acc = mapping.accessor<int64_t, 2>();

#pragma omp parallel for collapse(2)
  for (int layer = 0; layer < layer_num; ++layer) {
    for (size_t pair = 0; pair < pair_num; ++pair) {
      const int64_t row = static_cast<int64_t>(pair);
      const int64_t source_offset =
          element_num_per_block * mapping_acc[row][0];
      const int64_t target_offset =
          element_num_per_block * mapping_acc[row][1];

      scalar_t *key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
      std::memcpy(key_cache_ptr + target_offset, key_cache_ptr + source_offset,
                  block_bytes);

      scalar_t *value_cache_ptr = value_caches[layer].data_ptr<scalar_t>();
      std::memcpy(value_cache_ptr + target_offset,
                  value_cache_ptr + source_offset, block_bytes);
    }
  }
}

/// Scatters each token's per-head key and value vectors into the caches.
///
/// For every token whose slot_mapping entry is non-negative, the slot is split
/// into block_index = slot / block_size and block_offset = slot % block_size.
/// With `d` the element index inside a head, the flat indices written are:
///   key:   block_index*B + head*block_size*head_size
///          + (d / x)*block_size*x + block_offset*x + (d % x)
///   value: block_index*B + head*block_size*head_size
///          + d*block_size + block_offset
/// where B = num_heads * head_size * block_size. Tokens with a negative slot
/// are skipped and nothing is written for them.
///
/// @param key, value            source rows; token t, head h starts at
///                              t * {key,value}_stride + h * head_size.
/// @param key_cache, value_cache destination caches (written, assumed
///                              contiguous — TODO confirm with caller).
/// @param slot_mapping          one slot index per token (negative = skip).
/// @param x                     key chunk width. head_size must be a multiple
///                              of x, otherwise the last chunk runs past the
///                              head.
template <typename scalar_t>
void reshape_and_cache_cpu_impl(
    const scalar_t *__restrict__ key, const scalar_t *__restrict__ value,
    scalar_t *__restrict__ key_cache, scalar_t *__restrict__ value_cache,
    const int64_t *__restrict__ slot_mapping, const int num_tokens,
    const int key_stride, const int value_stride, const int num_heads,
    const int head_size, const int block_size, const int x) {
  // 64-bit offsets: token_idx * stride can exceed INT_MAX for large batches or
  // strided (e.g. fused QKV) inputs, and signed int overflow is UB.
  const int64_t block_elem_num =
      static_cast<int64_t>(num_heads) * head_size * block_size;
  const int64_t head_elem_num = static_cast<int64_t>(block_size) * head_size;

#pragma omp parallel for collapse(2)
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
      const int64_t slot_idx = slot_mapping[token_idx];
      if (slot_idx < 0) {
        continue;  // Negative slot: this token is not written to the cache.
      }

      const int64_t src_key_head_idx =
          static_cast<int64_t>(token_idx) * key_stride +
          static_cast<int64_t>(head_idx) * head_size;
      const int64_t src_value_head_idx =
          static_cast<int64_t>(token_idx) * value_stride +
          static_cast<int64_t>(head_idx) * head_size;
      const scalar_t *src_key_head_ptr = key + src_key_head_idx;
      const scalar_t *src_value_head_ptr = value + src_value_head_idx;

      const int64_t block_index = slot_idx / block_size;
      const int64_t block_offset = slot_idx % block_size;
      // Key and value caches share the same per-(block, head) base offset.
      const int64_t head_base =
          block_elem_num * block_index + head_idx * head_elem_num;
      scalar_t *target_key_head_ptr = key_cache + head_base;
      scalar_t *target_value_head_ptr = value_cache + head_base;

      // Keys: copy x contiguous elements at a time into the x-wide slot of
      // this token inside the (d / x)-th chunk row.
      for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) {
        const int64_t target_offset =
            static_cast<int64_t>(src_key_idx) * block_size + block_offset * x;
        for (int i = 0; i < x; ++i) {
          target_key_head_ptr[target_offset + i] =
              src_key_head_ptr[src_key_idx + i];
        }
      }

      // Values: element d lands in row d, column block_offset.
      for (int src_value_idx = 0; src_value_idx < head_size; ++src_value_idx) {
        const int64_t target_offset =
            static_cast<int64_t>(src_value_idx) * block_size + block_offset;
        target_value_head_ptr[target_offset] =
            src_value_head_ptr[src_value_idx];
      }
    }
  }
}
}; // namespace

void copy_blocks(std::vector<torch::Tensor> &key_caches,
                 std::vector<torch::Tensor> &value_caches,
86
                 const torch::Tensor& block_mapping) {
87
88
89
90
91
92
93
94
95
96
  int num_layers = key_caches.size();
  TORCH_CHECK(num_layers == value_caches.size());
  if (num_layers == 0) {
    return;
  }

  const int element_num_per_block = key_caches[0][0].numel();
  VLLM_DISPATCH_FLOATING_TYPES(
      key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
        CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
97
        copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
98
99
100
101
102
103
104
105
                                       element_num_per_block, num_layers);
        CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
      });
}

void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
                       torch::Tensor &key_cache, torch::Tensor &value_cache,
                       torch::Tensor &slot_mapping,
106
107
108
                       const std::string &kv_cache_dtype, float kv_scale) {
  TORCH_CHECK(kv_scale == 1.0f);

109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  VLLM_DISPATCH_FLOATING_TYPES(
      key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
        CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
        reshape_and_cache_cpu_impl<scalar_t>(
            key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
            key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
            slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride,
            value_stride, num_heads, head_size, block_size, x);
        CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
      });
}

void swap_blocks(torch::Tensor &src, torch::Tensor &dst,
131
                 const torch::Tensor&block_mapping) {
132
133
  TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
}