cache.cpp 8.8 KB
Newer Older
1
2
3
4
5
#include <map>
#include <vector>

#include "cpu_types.hpp"

6
7
8
9
10
11
// Pick the dtype-dispatch macro used by the cache kernels in this file.
// x86_64 builds use the dispatcher whose name carries the E5M2 suffix
// (presumably it adds an FP8 E5M2 case -- confirm against the macro's
// definition); every other target uses the plain floating-type dispatcher.
#if defined(__x86_64__)
  #define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2
#else
  #define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES
#endif

12
13
namespace {
template <typename scalar_t>
14
15
void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
                          std::vector<torch::Tensor> const& value_caches,
16
17
18
                          const torch::Tensor& mapping_pairs,
                          const int element_num_per_block,
                          const int layer_num) {
19
  const size_t pair_num = mapping_pairs.size(0);
20
21
22
23
  const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
#pragma omp parallel for collapse(2)
  for (int layer = 0; layer < layer_num; ++layer) {
    for (size_t pair = 0; pair < pair_num; ++pair) {
24
25
      int64_t source_offset =
          element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
26
      int64_t target_offset =
27
          element_num_per_block * mapping_pairs[pair][1].item<int64_t>();
28
29
30
      scalar_t* key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
      scalar_t* source_ptr = key_cache_ptr + source_offset;
      scalar_t* target_ptr = key_cache_ptr + target_offset;
31
32
      std::memcpy(target_ptr, source_ptr, block_bytes);

33
      scalar_t* value_cache_ptr = value_caches[layer].data_ptr<scalar_t>();
34
35
36
37
38
39
40
41
42
      source_ptr = value_cache_ptr + source_offset;
      target_ptr = value_cache_ptr + target_offset;
      std::memcpy(target_ptr, source_ptr, block_bytes);
    }
  }
}

// Scatters per-token key/value heads into the paged key and value caches.
//
// For every token whose slot is non-negative, slot_idx selects a block
// (slot_idx / block_size) and a position inside it (slot_idx % block_size).
// Within one head's region of a block (block_size * head_size elements):
//   * keys are stored in chunks of `x` consecutive elements: chunk c of the
//     token at position p is written at c * x * block_size + p * x;
//   * values are stored element-major: element e of the token at position p
//     is written at e * block_size + p.
//
// key / value: source buffers; token t, head h starts at
//     t * key_stride + h * head_size (resp. value_stride).
// key_cache / value_cache: destination buffers laid out as described above,
//     with num_heads * head_size * block_size elements per block.
// slot_mapping: one slot index per token; negative entries are skipped.
// x: key chunk width. The loops assume x > 0 and that head_size is a
//     multiple of x (not checked here).
//
// All offsets are computed in 64-bit arithmetic so that large caches or
// batches cannot overflow a 32-bit int.
template <typename scalar_t>
void reshape_and_cache_cpu_impl(
    const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
    scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
    const int64_t* __restrict__ slot_mapping, const int num_tokens,
    const int key_stride, const int value_stride, const int num_heads,
    const int head_size, const int block_size, const int x) {
  // Elements occupied by one head inside a block, and by a whole block.
  const int64_t head_elem_num = static_cast<int64_t>(block_size) * head_size;
  const int64_t block_elem_num = head_elem_num * num_heads;

#pragma omp parallel for collapse(2)
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
      const int64_t slot_idx = slot_mapping[token_idx];
      if (slot_idx < 0) {
        // Negative slots mark tokens that must not be cached.
        continue;
      }

      const int64_t src_head_offset = static_cast<int64_t>(head_idx) * head_size;
      const scalar_t* src_key_head_ptr =
          key + static_cast<int64_t>(token_idx) * key_stride + src_head_offset;
      const scalar_t* src_value_head_ptr =
          value + static_cast<int64_t>(token_idx) * value_stride +
          src_head_offset;

      const int64_t block_index = slot_idx / block_size;
      const int64_t block_offset = slot_idx % block_size;
      const int64_t target_head_offset =
          block_elem_num * block_index + head_elem_num * head_idx;
      scalar_t* target_key_head_ptr = key_cache + target_head_offset;
      scalar_t* target_value_head_ptr = value_cache + target_head_offset;

      // Keys: copy one chunk of x elements at a time.
      for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) {
        const int64_t target_offset =
            static_cast<int64_t>(src_key_idx) * block_size + block_offset * x;
        for (int i = 0; i < x; ++i) {
          target_key_head_ptr[target_offset + i] =
              src_key_head_ptr[src_key_idx + i];
        }
      }

      // Values: each element lands block_size apart.
      for (int src_value_idx = 0; src_value_idx < head_size; ++src_value_idx) {
        const int64_t target_offset =
            static_cast<int64_t>(src_value_idx) * block_size + block_offset;
        target_value_head_ptr[target_offset] =
            src_value_head_ptr[src_value_idx];
      }
    }
  }
}
89
};  // namespace
90

Thien Tran's avatar
Thien Tran committed
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
// Writes, for every token with a non-negative slot, the concatenation
// [kv_c row | k_pe row] into that token's entry of the paged kv_cache.
//
// The slot index selects the block (slot_idx / block_size) and the entry
// inside the block (slot_idx % block_size); the entry starts at
// block_idx * block_stride + block_offset * entry_stride. The first
// kv_lora_rank elements of the entry come from kv_c, the following pe_dim
// elements from k_pe.
//
// Source and destination offsets are computed in 64-bit arithmetic so that
// token_idx * stride cannot overflow a 32-bit int.
template <typename scalar_t>
void concat_and_cache_mla_cpu_impl(
    const scalar_t* __restrict__ kv_c,  // [num_tokens, kv_lora_rank]
    const scalar_t* __restrict__ k_pe,  // [num_tokens, pe_dim]
    scalar_t* __restrict__ kv_cache,  // [num_blocks, block_size, (kv_lora_rank
                                      // + pe_dim)]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int num_tokens,                      //
    const int block_stride,                    //
    const int entry_stride,                    //
    const int kv_c_stride,                     //
    const int k_pe_stride,                     //
    const int kv_lora_rank,                    //
    const int pe_dim,                          //
    const int block_size                       //
) {
#pragma omp parallel for
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    const int64_t slot_idx = slot_mapping[token_idx];
    // NOTE: slot_idx can be -1 if the token is padded
    if (slot_idx < 0) {
      continue;
    }
    const int64_t block_idx = slot_idx / block_size;
    const int64_t block_offset = slot_idx % block_size;

    // Start of this token's cache entry; both source rows land inside it.
    scalar_t* __restrict__ entry =
        kv_cache + block_idx * block_stride + block_offset * entry_stride;

    // Copies `size` elements of this token's row of `src` into the entry,
    // starting `offset` elements into the entry.
    auto copy_row = [&](const scalar_t* __restrict__ src, int64_t src_stride,
                        int size, int64_t offset) {
      const scalar_t* src_row = src + token_idx * src_stride;
      for (int i = 0; i < size; ++i) {
        entry[offset + i] = src_row[i];
      }
    };

    copy_row(kv_c, kv_c_stride, kv_lora_rank, 0);
    copy_row(k_pe, k_pe_stride, pe_dim, kv_lora_rank);
  }
}

133
134
135
136
137
// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
                 std::vector<torch::Tensor> const& value_caches,
138
                 const torch::Tensor& block_mapping) {
139
  unsigned num_layers = key_caches.size();
140
141
142
143
144
145
  TORCH_CHECK(num_layers == value_caches.size());
  if (num_layers == 0) {
    return;
  }

  const int element_num_per_block = key_caches[0][0].numel();
146
147
148
149
150
151
  DISPATCH_MACRO(key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
    CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
    copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
                                   element_num_per_block, num_layers);
    CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
  });
152
153
}

154
155
156
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                       torch::Tensor& key_cache, torch::Tensor& value_cache,
                       torch::Tensor& slot_mapping,
157
158
                       const std::string& kv_cache_dtype,
                       torch::Tensor& k_scale, torch::Tensor& v_scale) {
159
160
161
162
163
164
165
166
167
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

168
169
170
171
172
173
174
175
176
  DISPATCH_MACRO(key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
    CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
    reshape_and_cache_cpu_impl<scalar_t>(
        key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
        key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
        slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride, value_stride,
        num_heads, head_size, block_size, x);
    CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
  });
177
178
}

Thien Tran's avatar
Thien Tran committed
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
// Stores, for each token, the row of kv_c followed by the row of k_pe into the
// cache entry selected by slot_mapping, dispatching on the dtype of kv_c.
// The last cache dimension must equal kv_lora_rank + pe_dim, and the "fp8"
// cache dtype is rejected. `scale` is not used by this CPU implementation.
void concat_and_cache_mla(
    torch::Tensor& kv_c,          // [num_tokens, kv_lora_rank]
    torch::Tensor& k_pe,          // [num_tokens, pe_dim]
    torch::Tensor& kv_cache,      // [num_blocks, block_size, (kv_lora_rank +
                                  // pe_dim)]
    torch::Tensor& slot_mapping,  // [num_tokens] or [num_actual_tokens]
    const std::string& kv_cache_dtype, torch::Tensor& scale) {
  // Problem sizes; the token count follows slot_mapping, not kv_c.
  const int token_count = slot_mapping.size(0);
  const int latent_width = kv_c.size(1);
  const int pe_width = k_pe.size(1);
  const int slots_per_block = kv_cache.size(1);

  TORCH_CHECK(kv_cache.size(2) == latent_width + pe_width);
  TORCH_CHECK(kv_cache_dtype != "fp8");

  // Row strides of the sources, then block / entry strides of the cache.
  const int kv_c_row_stride = kv_c.stride(0);
  const int k_pe_row_stride = k_pe.stride(0);
  const int cache_block_stride = kv_cache.stride(0);
  const int cache_entry_stride = kv_cache.stride(1);

  VLLM_DISPATCH_FLOATING_TYPES(
      kv_c.scalar_type(), "concat_and_cache_mla_cpu_impl", [&] {
        CPU_KERNEL_GUARD_IN(concat_and_cache_mla_cpu_impl)
        const scalar_t* kv_c_data = kv_c.data_ptr<scalar_t>();
        const scalar_t* k_pe_data = k_pe.data_ptr<scalar_t>();
        scalar_t* cache_data = kv_cache.data_ptr<scalar_t>();
        const int64_t* slot_data = slot_mapping.data_ptr<int64_t>();
        concat_and_cache_mla_cpu_impl<scalar_t>(
            kv_c_data, k_pe_data, cache_data, slot_data, token_count,
            cache_block_stride, cache_entry_stride, kv_c_row_stride,
            k_pe_row_stride, latent_width, pe_width, slots_per_block);
        CPU_KERNEL_GUARD_OUT(concat_and_cache_mla_cpu_impl)
      });
}

211
212
// CPU stub for swap_blocks: the operation is not implemented for the CPU
// backend, so every call fails the check below with an error. All arguments
// are ignored.
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                 const torch::Tensor& block_mapping) {
  TORCH_CHECK(false, "swap_blocks is unsupported on CPU.");
}