// kv_cache.hpp
#pragma once

#include "infinicore/context/context.hpp"
#include "infinicore/device.hpp"
#include "infinicore/tensor.hpp"

#include <algorithm>
#include <cstddef>
#include <memory>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>

#include <spdlog/spdlog.h>
namespace infinilm::cache {

/**
 * @brief Simple KV cache structure for incremental decoding
 *
20
 * Stores key and value caches with shape [batch_size, n_kv_head, capacity, head_dim]
21
22
23
24
25
26
 * Similar to DynamicLayer in Python cache_utils.py
 *
 * This is a common component that can be used by any model architecture
 * that needs KV caching for attention mechanisms.
 */
struct KVCache {
27
28
29
30
31
32
33
    infinicore::Tensor k_cache;          // [batch_size, n_kv_head, capacity, head_dim]
    infinicore::Tensor v_cache;          // [batch_size, n_kv_head, capacity, head_dim]
    std::vector<size_t> cache_positions; // Current position in cache
    size_t max_capacity;                 // Maximum capacity of cache
    bool initialized;                    // Whether cache has been initialized

    KVCache() : max_capacity(0), initialized(false) {}
34
35
36
37
38
39
40
41
42

    /**
     * @brief Initialize or update cache capacity
     * @param num_kv_heads Number of key-value heads
     * @param head_dim Head dimension
     * @param seq_len Sequence length of new tokens
     * @param dtype Data type
     * @param device Device
     */
43
    void ensure_capacity(size_t batch_size, size_t num_kv_heads, size_t head_dim, size_t seq_len,
44
                         infinicore::DataType dtype, const infinicore::Device &device) {
45
        size_t required_capacity = seq_len + std::accumulate(cache_positions.begin(), cache_positions.end(), 0, [](int a, int b) { return std::max(a, b); });
46
47
48

        // Lazy initialization
        if (!initialized) {
49
            max_capacity = std::max(required_capacity, size_t(4096)); // Start with at least 4096
50
            k_cache = infinicore::Tensor::empty({batch_size, num_kv_heads, max_capacity, head_dim},
51
                                                dtype, device);
52
            v_cache = infinicore::Tensor::empty({batch_size, num_kv_heads, max_capacity, head_dim},
53
                                                dtype, device);
54
            cache_positions = std::vector<size_t>(batch_size, 0);
55
56
57
58
            initialized = true;
        }
        // Grow cache if needed (similar to DynamicLayer in Python)
        else if (required_capacity > max_capacity) {
59
60
61
62
63
64
65
66
67
            size_t new_capacity = std::max(max_capacity * 2, required_capacity + max_capacity);
            size_t new_batch_size = std::max(batch_size, k_cache->shape()[0]);
            if (num_kv_heads != k_cache->shape()[1] || head_dim != k_cache->shape()[3]) {
                throw std::runtime_error("KVCache ensure_capacity: num_kv_heads or head_dim mismatch with existing cache.");
            }
            if (new_batch_size > cache_positions.size()) {
                cache_positions.resize(new_batch_size, 0);
            }
            auto k_new = infinicore::Tensor::empty({new_batch_size, num_kv_heads, new_capacity, head_dim},
68
                                                   dtype, device);
69
            auto v_new = infinicore::Tensor::empty({new_batch_size, num_kv_heads, new_capacity, head_dim},
70
71
72
                                                   dtype, device);

            // Copy existing cache data
73
74
75
76
77
78
79
80
            for (size_t b = 0; b < new_batch_size; ++b) {
                size_t cache_position = cache_positions[b];
                if (cache_position > 0) {
                    auto k_slice = k_cache->narrow({{0, b, 1}, {2, 0, cache_position}});
                    auto v_slice = v_cache->narrow({{0, b, 1}, {2, 0, cache_position}});
                    k_new->narrow({{0, b, 1}, {2, 0, cache_position}})->copy_from(k_slice);
                    v_new->narrow({{0, b, 1}, {2, 0, cache_position}})->copy_from(v_slice);
                }
81
82
83
84
85
86
87
88
            }

            k_cache = k_new;
            v_cache = v_new;
            max_capacity = new_capacity;
        }
    }

89
90
91
92
93
94
    KVCache(size_t max_batch_size, size_t n_kv_head, size_t head_dim, infinicore::DataType dtype, size_t max_seqlen = 4096, infinicore::Device device = infinicore::context::getDevice())
        : max_capacity(max_seqlen), initialized(false) {
        cache_positions = std::vector<size_t>(max_batch_size, 0);
        ensure_capacity(max_batch_size, n_kv_head, head_dim, max_capacity, dtype, device);
    }

95
96
    /**
     * @brief Update cache with new key and value states
97
98
     * @param k_new New key states [batch_size, n_kv_head, seq_len, head_dim]
     * @param v_new New value states [batch_size, n_kv_head, seq_len, head_dim]
99
100
101
102
103
104
105
106
     * @return Tuple of (k_total, v_total) with shape [n_kv_head, total_seq_len, head_dim]
     *
     * Note: This method writes to the cache. If using with attention op, the attention op
     * also writes to the cache, so this should be called AFTER attention, not before.
     */
    std::pair<infinicore::Tensor, infinicore::Tensor> update(
        const infinicore::Tensor &k_new,
        const infinicore::Tensor &v_new) {
107
108
109
110
111
112
113
        if (k_new->ndim() != 4 || v_new->ndim() != 4) {
            throw std::runtime_error("KVCache update: k_new and v_new must be 4D tensors in [batch_size, n_kv_head, seq_len, head_dim] form.");
        }
        size_t batch_size = k_new->shape()[0];
        size_t num_kv_heads = k_new->shape()[1];
        size_t seq_len = k_new->shape()[2];
        size_t head_dim = k_new->shape()[3];
114
115

        // Ensure capacity
116
        ensure_capacity(batch_size, num_kv_heads, head_dim, seq_len,
117
                        k_new->dtype(), k_new->device());
118
119

        // Copy new k/v into cache at current position
120
121
122
        bool all_equal = cache_positions.empty() || std::equal(cache_positions.begin() + 1, cache_positions.end(), cache_positions.begin());
        if (all_equal) {
            auto cache_position = cache_positions[0];
123

124
125
126
127
            auto k_dst = k_cache->narrow({{2, cache_position, seq_len}});
            auto v_dst = v_cache->narrow({{2, cache_position, seq_len}});
            k_dst->copy_from(k_new);
            v_dst->copy_from(v_new);
128

129
130
131
132
133
            // Update position
            cache_position += seq_len;
            for (size_t b = 0; b < batch_size; ++b) {
                cache_positions[b] = cache_position;
            }
134

135
136
137
138
139
140
141
142
            // Return the total cache up to current position
            auto k_total = k_cache->narrow({{2, 0, cache_position}});
            auto v_total = v_cache->narrow({{2, 0, cache_position}});

            return std::make_pair(k_total, v_total);
        } else {
            throw std::runtime_error("KVCache update: cache positions must be equal among a batch.");
        }
143
144
145
    }
};

146
} // namespace infinilm::cache