#include "kv_cache.hpp"

#include "../utils.hpp"

#include "infinicore/ops.hpp"

#include <cstdint>
#include <limits>
#include <memory>
#include <stdexcept>
#include <tuple>

namespace infinilm::cache {

// ==========================
// StaticKVCacheConfig
// ==========================

/// Config for the statically-allocated KV cache: a fixed batch size and a
/// fixed maximum cache length are reserved up front at construction time.
StaticKVCacheConfig::StaticKVCacheConfig(
    infinicore::Size _max_batch_size,
    infinicore::Size _max_cache_len)
    : max_batch_size_(_max_batch_size),
      max_cache_len_(_max_cache_len) {}

/// Polymorphic copy of this config.
/// NOTE(review): extraction stripped the template argument of the return
/// type; `CacheConfig` is assumed to be the base class declared in
/// kv_cache.hpp — confirm against the header.
std::unique_ptr<CacheConfig> StaticKVCacheConfig::unique_copy() const {
    return std::make_unique<StaticKVCacheConfig>(*this);
}

infinicore::Size StaticKVCacheConfig::max_batch_size() const {
    return max_batch_size_;
}

infinicore::Size StaticKVCacheConfig::max_cache_len() const {
    return max_cache_len_;
}

// ==========================
// StaticKVCache
// ==========================

/// Pre-allocates the full K and V buffers for this rank:
/// [num_layers, batch, num_heads / tp_size, cache_len, head_dim].
/// A max_cache_len of 0 or of Size's max value means "unset" and falls back
/// to the model's max_positional_embedding.
StaticKVCache::StaticKVCache(
    infinicore::Size k_dim,
    infinicore::Size v_dim,
    infinicore::Size num_k_heads,
    infinicore::Size num_v_heads,
    infinicore::Size num_layers,
    infinicore::Size max_positional_embedding,
    infinicore::DataType dtype,
    const StaticKVCacheConfig &config,
    const engine::distributed::RankInfo &rank_info)
    : Cache(),
      k_dim_(k_dim),
      v_dim_(v_dim),
      // KV heads are sharded evenly across tensor-parallel ranks.
      num_rank_k_heads_(num_k_heads / rank_info.tp_size),
      num_rank_v_heads_(num_v_heads / rank_info.tp_size),
      rank_batch_size_(config.max_batch_size()),
      cache_len_(config.max_cache_len() == std::numeric_limits<infinicore::Size>::max()
                         || config.max_cache_len() == 0
                     ? max_positional_embedding
                     : config.max_cache_len()),
      rank_num_layers_(num_layers),
      dtype_(dtype) {
    // Allocate K cache
    k_caches_ = infinicore::Tensor::empty(
        {rank_num_layers_, rank_batch_size_, num_rank_k_heads_, cache_len_, k_dim_},
        dtype_, rank_info.device);
    // Allocate V cache
    v_caches_ = infinicore::Tensor::empty(
        {rank_num_layers_, rank_batch_size_, num_rank_v_heads_, cache_len_, v_dim_},
        dtype_, rank_info.device);
}

/// Writes `k`/`v` (indexed as [batch, heads, seq, dim] — dims 0 and 2 are
/// read below) into layer `layer_idx` at the current write position and
/// returns views over the whole valid prefix [0, cache_pos + update_len).
/// NOTE(review): only element 0 of `cache_lengths` is consulted, i.e. all
/// sequences of the static batch are assumed to share one length — confirm.
std::tuple<infinicore::Tensor, infinicore::Tensor> StaticKVCache::update(
    size_t layer_idx,
    const infinicore::Tensor &k,
    const infinicore::Tensor &v,
    const infinicore::Tensor &cache_lengths) {
    ASSERT(layer_idx < rank_num_layers_);
    auto batch_size = k->size(0);
    auto update_len = k->size(2);
    // The device-side length tensor must be copied to host before reading.
    // NOTE(review): the cast's element type was lost in extraction; int64 is
    // assumed here — confirm the dtype of `cache_lengths`.
    size_t cache_pos = reinterpret_cast<const int64_t *>(
        cache_lengths->to(infinicore::Device::cpu())->data())[0];
    auto result_len = cache_pos + update_len;
    ASSERT(result_len <= cache_len_);
    ASSERT_EQ(batch_size, rank_batch_size_);

    // Per-layer views: [batch, heads, cache_len, dim].
    auto k_cache_layer = k_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
    auto v_cache_layer = v_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);

    // Copy the new tokens into slots [cache_pos, cache_pos + update_len).
    auto k_cache_update = k_cache_layer->narrow({{2, cache_pos, update_len}});
    auto v_cache_update = v_cache_layer->narrow({{2, cache_pos, update_len}});
    k_cache_update->copy_from(k);
    v_cache_update->copy_from(v);

    // Hand back the full cached prefix: old tokens plus the new ones.
    auto k_total = k_cache_layer->narrow({{2, 0, result_len}});
    auto v_total = v_cache_layer->narrow({{2, 0, result_len}});
    return {k_total, v_total};
}

// ==========================
// PagedKVCacheConfig
// ==========================

/// Config for the paged KV cache: a total memory budget plus the number of
/// tokens stored per block.
PagedKVCacheConfig::PagedKVCacheConfig(
    size_t max_kv_memory_bytes,
    size_t block_size)
    : max_kv_memory_bytes_(max_kv_memory_bytes),
      block_size_(block_size) {}

/// Polymorphic copy of this config.
/// NOTE(review): return-type template argument reconstructed as the assumed
/// `CacheConfig` base — confirm against kv_cache.hpp.
std::unique_ptr<CacheConfig> PagedKVCacheConfig::unique_copy() const {
    return std::make_unique<PagedKVCacheConfig>(*this);
}

size_t PagedKVCacheConfig::max_kv_memory_bytes() const {
    return max_kv_memory_bytes_;
}

size_t PagedKVCacheConfig::block_size() const {
    return block_size_;
}

// ==========================
// PagedKVCache
// ==========================

/// Allocates block-pool K/V buffers sized to fit within the configured
/// memory budget. Throws std::runtime_error if the budget yields zero blocks.
PagedKVCache::PagedKVCache(
    infinicore::Size k_dim,
    infinicore::Size v_dim,
    infinicore::Size num_k_heads,
    infinicore::Size num_v_heads,
    infinicore::Size num_layers,
    infinicore::DataType dtype,
    const PagedKVCacheConfig &config,
    const engine::distributed::RankInfo &rank_info)
    : Cache(),
      k_dim_(k_dim),
      v_dim_(v_dim),
      num_rank_k_heads_(num_k_heads / rank_info.tp_size),
      num_rank_v_heads_(num_v_heads / rank_info.tp_size),
      rank_num_layers_(num_layers),
      dtype_(dtype),
      block_size_(config.block_size()) {
    // blocks = budget / (bytes per token of K+V) / tokens per block / layers.
    // Integer division intentionally rounds down so we never exceed budget.
    num_blocks_per_layer_ = config.max_kv_memory_bytes()
                          / (k_dim * num_rank_k_heads_ + v_dim * num_rank_v_heads_)
                          / block_size_
                          / rank_num_layers_
                          / infinicore::dsize(dtype_);
    if (num_blocks_per_layer_ == 0) {
        throw std::runtime_error("Not enough memory for KV cache");
    }

    // [num_layers, num_blocks, num_rank_k_heads, block_size, k_dim]
    k_caches_ = infinicore::Tensor::empty(
        {rank_num_layers_, num_blocks_per_layer_, num_rank_k_heads_, block_size_, k_dim_},
        dtype_, rank_info.device);
    // [num_layers, num_blocks, num_rank_v_heads, block_size, v_dim]
    v_caches_ = infinicore::Tensor::empty(
        {rank_num_layers_, num_blocks_per_layer_, num_rank_v_heads_, block_size_, v_dim_},
        dtype_, rank_info.device);
}

/// Scatters `k`/`v` into this layer's block pool at the slots given by
/// `slot_mapping` (delegated to the paged_caching_ op) and returns views of
/// the layer's whole K and V pools.
std::tuple<infinicore::Tensor, infinicore::Tensor> PagedKVCache::update(
    size_t layer_idx,
    const infinicore::Tensor &k,
    const infinicore::Tensor &v,
    const infinicore::Tensor &slot_mapping) {
    auto k_cache_layer = k_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
    auto v_cache_layer = v_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);

    infinicore::op::paged_caching_(k, v, k_cache_layer, v_cache_layer, slot_mapping);

    return {k_cache_layer, v_cache_layer};
}

} // namespace infinilm::cache