Commit bf74389d authored by PanZezhong's avatar PanZezhong
Browse files

issue/168 get contiguous paged kv cache

parent f246c4f1
......@@ -188,8 +188,7 @@ std::tuple<infinicore::Tensor, infinicore::Tensor> PagedKVCache::update(
const infinicore::Tensor &v,
const infinicore::Tensor &slot_mapping) {
auto k_cache_layer = k_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
auto v_cache_layer = v_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
auto &&[k_cache_layer, v_cache_layer] = get_paged_kv(layer_idx);
infinicore::op::paged_caching_(k,
v,
......@@ -198,4 +197,69 @@ std::tuple<infinicore::Tensor, infinicore::Tensor> PagedKVCache::update(
slot_mapping);
return {k_cache_layer, v_cache_layer};
}
std::tuple<infinicore::Tensor, infinicore::Tensor>
PagedKVCache::get_paged_kv(size_t layer_idx) {
    // Select the K and V pools belonging to one layer and drop the layer axis,
    // yielding per-layer views shaped
    // [num_blocks, num_rank_{k,v}_heads, block_size, {k,v}_dim].
    return {k_caches_->narrow({{0, layer_idx, 1}})->squeeze(0),
            v_caches_->narrow({{0, layer_idx, 1}})->squeeze(0)};
}
std::tuple<infinicore::Tensor, infinicore::Tensor>
PagedKVCache::get_contiguous_kv(
    size_t layer_idx,
    const infinicore::Tensor block_tables,
    const infinicore::Tensor cache_lens,
    const infinicore::Tensor input_offsets,
    size_t request_id) {
    // All indexing metadata must be int64: the raw-pointer reads below
    // depend on it.
    ASSERT_EQ(block_tables->dtype(), infinicore::DataType::I64);
    ASSERT_EQ(cache_lens->dtype(), infinicore::DataType::I64);
    ASSERT_EQ(input_offsets->dtype(), infinicore::DataType::I64);
    // Bring the metadata to host memory and wait for the transfer to finish
    // before dereferencing it.
    auto block_tables_cpu = block_tables->to(infinicore::Device::cpu());
    auto cache_lens_cpu = cache_lens->to(infinicore::Device::cpu());
    auto input_offsets_cpu = input_offsets->to(infinicore::Device::cpu());
    infinicore::context::syncDevice();
    // Paged layout per layer:
    // [num_blocks, num_rank_{k,v}_heads, block_size, {k,v}_dim]
    auto &&[k_cache_layer, v_cache_layer] = get_paged_kv(layer_idx);
    const auto *cache_lens_ptr = reinterpret_cast<const int64_t *>(cache_lens_cpu->data());
    const auto *input_offsets_ptr = reinterpret_cast<const int64_t *>(input_offsets_cpu->data());
    // Total tokens for this request = tokens already cached + tokens in the
    // current input (the difference of consecutive input offsets).
    const int64_t total_len = cache_lens_ptr[request_id]
        + (input_offsets_ptr[request_id + 1] - input_offsets_ptr[request_id]);
    auto full_k = infinicore::Tensor::empty(
        {num_rank_k_heads_, static_cast<size_t>(total_len), k_dim_},
        k_cache_layer->dtype(), k_cache_layer->device());
    auto full_v = infinicore::Tensor::empty(
        {num_rank_v_heads_, static_cast<size_t>(total_len), v_dim_},
        v_cache_layer->dtype(), v_cache_layer->device());
    // Physical block id of the b-th logical block of this request
    // (was a C-style cast; use the named cast consistently with the reads above).
    auto block_id_at = [&](size_t b) -> size_t {
        const auto *p = reinterpret_cast<const int64_t *>(
            block_tables_cpu->narrow({{0, request_id, 1}, {1, b, 1}})->data());
        return static_cast<size_t>(*p);
    };
    // Copy `len` tokens of physical block `bid` into the contiguous buffers
    // starting at token offset `dst` (narrowing a full block is a no-op view).
    auto copy_block = [&](size_t dst, size_t bid, size_t len) {
        full_k->narrow({{1, dst, len}})
            ->copy_from(k_cache_layer->narrow({{0, bid, 1}})->squeeze(0)->narrow({{1, 0, len}}));
        full_v->narrow({{1, dst, len}})
            ->copy_from(v_cache_layer->narrow({{0, bid, 1}})->squeeze(0)->narrow({{1, 0, len}}));
    };
    const size_t nblocks = total_len / block_size_;
    const size_t r = total_len % block_size_;
    for (size_t b = 0; b < nblocks; b++) {
        copy_block(b * block_size_, block_id_at(b), block_size_);
    }
    if (r > 0) {
        // Trailing partially-filled block.
        copy_block(nblocks * block_size_, block_id_at(nblocks), r);
    }
    return {full_k, full_v};
}
} // namespace infinilm::cache
......@@ -113,7 +113,7 @@ public:
/**
* @brief Update Paged KV cache at a given layer given slot info for each token.
*
* @param layer_idx Which transformer layer
* @param layer_idx Which paged attention layer
* @param k [num_rank_k_heads, seq_len, k_dim]
* @param v [num_rank_v_heads, seq_len, v_dim]
* @param slot_mapping [seq_len]
......@@ -128,7 +128,41 @@ public:
const infinicore::Tensor &v,
const infinicore::Tensor &slot_mapping);
~PagedKVCache() override = default;
/**
 * @brief Get the paged K/V cache views of a given layer.
 *
 * @param layer_idx Which paged attention layer
 *
 * @return (paged_k, paged_v)
 * paged_k: [num_blocks, num_rank_k_heads, block_size, k_dim]
 * paged_v: [num_blocks, num_rank_v_heads, block_size, v_dim]
 */
std::tuple<infinicore::Tensor, infinicore::Tensor>
get_paged_kv(size_t layer_idx);
/**
 * @brief Gather one request's KV cache at a given layer into contiguous
 * tensors, given its metadata within a continuous batch of requests.
*
* @param layer_idx Which paged attention layer
* @param block_tables [num_requests, max_blocks_per_request]
* @param cache_lens [num_requests]
* @param input_offsets [num_requests + 1]
* @param request_id Which request among a continuous batch of requests
*
* @return (full_k, full_v)
* full_k: [num_rank_k_heads, total_len, k_dim]
* full_v: [num_rank_v_heads, total_len, v_dim]
*/
std::tuple<infinicore::Tensor, infinicore::Tensor>
get_contiguous_kv(size_t layer_idx,
const infinicore::Tensor block_tables,
const infinicore::Tensor cache_lens,
const infinicore::Tensor input_offsets,
size_t request_id = 0);
~PagedKVCache() override
= default;
private:
infinicore::Size k_dim_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment