#pragma once

#include "infinicore/context/context.hpp"
#include "infinicore/device.hpp"
#include "infinicore/tensor.hpp"

#include "cache_config.hpp"

#include <algorithm>
#include <cstddef>
#include <memory>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>

#include <spdlog/spdlog.h>
namespace infinilm::cache {

/**
Ceng's avatar
Ceng committed
20
 * @brief Single layer's KV cache for incremental decoding
21
 *
22
 * Stores key and value caches with shape [batch_size, n_kv_head, capacity, head_dim]
23
24
 * Similar to DynamicLayer in Python cache_utils.py
 *
Ceng's avatar
Ceng committed
25
 * This represents a single layer's cache within a model-level cache container.
26
 */
Ceng's avatar
Ceng committed
27
struct KVCacheLayer {
28
29
30
31
    infinicore::Tensor k_cache;          // [batch_size, n_kv_head, capacity, head_dim]
    infinicore::Tensor v_cache;          // [batch_size, n_kv_head, capacity, head_dim]
    std::vector<size_t> cache_positions; // Current position in cache
    size_t max_capacity;                 // Maximum capacity of cache
32
33
34
    size_t initial_capacity;             // Initial capacity from config
    size_t initial_batch_size;           // Initial batch size from config
    float growth_factor;                 // Growth factor for dynamic resizing
35
36
    bool initialized;                    // Whether cache has been initialized

37
38
    KVCacheLayer() : max_capacity(0), initial_capacity(4096), initial_batch_size(1),
                     growth_factor(2.0f), initialized(false) {}
39
40

    /**
41
42
     * @brief Initialize or update cache capacity with config parameters
     * @param batch_size Current batch size
43
44
45
46
47
     * @param num_kv_heads Number of key-value heads
     * @param head_dim Head dimension
     * @param seq_len Sequence length of new tokens
     * @param dtype Data type
     * @param device Device
48
     * @param cache_config Cache configuration parameters
49
     */
50
    void ensure_capacity(size_t batch_size, size_t num_kv_heads, size_t head_dim, size_t seq_len,
Ceng's avatar
Ceng committed
51
                         infinicore::DataType dtype, const infinicore::Device &device,
52
                         const CacheConfig &cache_config) {
53
        size_t required_capacity = seq_len + std::accumulate(cache_positions.begin(), cache_positions.end(), 0, [](int a, int b) { return std::max(a, b); });
54

Ceng's avatar
Ceng committed
55
56
57
        // VALIDATION: Verify input parameters
        if (num_kv_heads == 0 || head_dim == 0 || seq_len == 0) {
            SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Invalid parameters - num_kv_heads: {}, head_dim: {}, seq_len: {}",
58
                         num_kv_heads, head_dim, seq_len);
Ceng's avatar
Ceng committed
59
60
61
            throw std::runtime_error("KV cache ensure_capacity: invalid parameters");
        }

62
63
64
65
66
67
68
        // Store config parameters on first initialization
        if (!initialized) {
            initial_capacity = cache_config.initial_capacity;
            initial_batch_size = cache_config.initial_batch_size;
            growth_factor = cache_config.growth_factor;
        }

69
70
        // Lazy initialization
        if (!initialized) {
71
72
73
74
75
76
77
            // Use max of required capacity and initial capacity from config
            max_capacity = std::max(required_capacity, initial_capacity);

            // Use max of current batch size and initial batch size from config
            size_t alloc_batch_size = std::max(batch_size, initial_batch_size);

            k_cache = infinicore::Tensor::empty({alloc_batch_size, num_kv_heads, max_capacity, head_dim},
78
                                                dtype, device);
79
            v_cache = infinicore::Tensor::empty({alloc_batch_size, num_kv_heads, max_capacity, head_dim},
80
                                                dtype, device);
81
            cache_positions = std::vector<size_t>(alloc_batch_size, 0);
82
            initialized = true;
Ceng's avatar
Ceng committed
83

84
85
86
            spdlog::debug("Initialized KV cache with batch_size={}, capacity={} (config: initial_batch={}, initial_capacity={})",
                          alloc_batch_size, max_capacity, initial_batch_size, initial_capacity);

Ceng's avatar
Ceng committed
87
            // VALIDATION: Verify cache was created correctly
88
89
            if (k_cache->shape()[0] != alloc_batch_size || k_cache->shape()[1] != num_kv_heads || k_cache->shape()[2] != max_capacity || k_cache->shape()[3] != head_dim) {
                SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Cache shape mismatch after initialization");
Ceng's avatar
Ceng committed
90
91
                throw std::runtime_error("KV cache initialization: shape mismatch");
            }
92
        }
93
        // Grow cache if needed using growth factor from config
94
        else if (required_capacity > max_capacity) {
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
            if (!cache_config.allow_expand) {
                SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Cache expansion not allowed by config");
                throw std::runtime_error("KV cache expansion not allowed");
            }
            // Calculate new capacity using growth factor
            size_t new_capacity = static_cast<size_t>(
                std::max(static_cast<float>(max_capacity) * growth_factor,
                         static_cast<float>(required_capacity + max_capacity)));

            // Ensure we don't exceed max_position_embeddings if specified
            if (cache_config.max_kv_cache_length != 0) {
                new_capacity = std::min(new_capacity, cache_config.max_kv_cache_length);
            }

            // Ensure we grow by at least some minimum amount
            size_t min_growth = 256;
            if (new_capacity - max_capacity < min_growth) {
                new_capacity = max_capacity + min_growth;
            }

115
116
117
118
119
120
121
            size_t new_batch_size = std::max(batch_size, k_cache->shape()[0]);
            if (num_kv_heads != k_cache->shape()[1] || head_dim != k_cache->shape()[3]) {
                throw std::runtime_error("KVCache ensure_capacity: num_kv_heads or head_dim mismatch with existing cache.");
            }
            if (new_batch_size > cache_positions.size()) {
                cache_positions.resize(new_batch_size, 0);
            }
122

123
            auto k_new = infinicore::Tensor::empty({new_batch_size, num_kv_heads, new_capacity, head_dim},
124
                                                   dtype, device);
125
            auto v_new = infinicore::Tensor::empty({new_batch_size, num_kv_heads, new_capacity, head_dim},
126
127
                                                   dtype, device);

128
129
130
            spdlog::debug("Growing KV cache from capacity {} to {} (growth_factor={})",
                          max_capacity, new_capacity, growth_factor);

131
            // Copy existing cache data
132
133
134
135
136
137
138
139
            for (size_t b = 0; b < new_batch_size; ++b) {
                size_t cache_position = cache_positions[b];
                if (cache_position > 0) {
                    auto k_slice = k_cache->narrow({{0, b, 1}, {2, 0, cache_position}});
                    auto v_slice = v_cache->narrow({{0, b, 1}, {2, 0, cache_position}});
                    k_new->narrow({{0, b, 1}, {2, 0, cache_position}})->copy_from(k_slice);
                    v_new->narrow({{0, b, 1}, {2, 0, cache_position}})->copy_from(v_slice);
                }
140
141
142
143
144
            }

            k_cache = k_new;
            v_cache = v_new;
            max_capacity = new_capacity;
Ceng's avatar
Ceng committed
145
146
147

            // VALIDATION: Verify cache was grown correctly
            if (k_cache->shape()[2] != new_capacity) {
148
                SPDLOG_ERROR("KVCacheLayer::ensure_capacity: New cache capacity mismatch");
Ceng's avatar
Ceng committed
149
150
151
152
153
154
                throw std::runtime_error("KV cache growth: capacity mismatch");
            }
        }

        // VALIDATION: Final check that capacity is sufficient
        if (required_capacity > max_capacity) {
155
            SPDLOG_ERROR("KVCacheLayer::ensure_capacity: Capacity still insufficient after growth");
Ceng's avatar
Ceng committed
156
            throw std::runtime_error("KV cache ensure_capacity: capacity insufficient");
157
158
159
160
161
        }
    }

    /**
     * @brief Update cache with new key and value states
162
163
     * @param k_new New key states [batch_size, n_kv_head, seq_len, head_dim]
     * @param v_new New value states [batch_size, n_kv_head, seq_len, head_dim]
164
     * @param cache_config Cache configuration for capacity management
Ceng's avatar
Ceng committed
165
     * @return Tuple of (k_total, v_total) with shape [batch_size, n_kv_head, total_seq_len, head_dim]
166
167
168
     */
    std::pair<infinicore::Tensor, infinicore::Tensor> update(
        const infinicore::Tensor &k_new,
169
170
        const infinicore::Tensor &v_new,
        const CacheConfig &cache_config) {
171
        if (k_new->ndim() != 4 || v_new->ndim() != 4) {
172
            throw std::runtime_error("KVCache update: k_new and v_new must be 4D tensors");
173
174
175
176
177
        }
        size_t batch_size = k_new->shape()[0];
        size_t num_kv_heads = k_new->shape()[1];
        size_t seq_len = k_new->shape()[2];
        size_t head_dim = k_new->shape()[3];
178

179
        // Ensure capacity with cache config
180
        ensure_capacity(batch_size, num_kv_heads, head_dim, seq_len,
181
                        k_new->dtype(), k_new->device(), cache_config);
182
183

        // Copy new k/v into cache at current position
184
185
186
        bool all_equal = cache_positions.empty() || std::equal(cache_positions.begin() + 1, cache_positions.end(), cache_positions.begin());
        if (all_equal) {
            auto cache_position = cache_positions[0];
187

188
189
190
191
            auto k_dst = k_cache->narrow({{2, cache_position, seq_len}});
            auto v_dst = v_cache->narrow({{2, cache_position, seq_len}});
            k_dst->copy_from(k_new);
            v_dst->copy_from(v_new);
192

193
194
195
196
197
            // Update position
            cache_position += seq_len;
            for (size_t b = 0; b < batch_size; ++b) {
                cache_positions[b] = cache_position;
            }
198

199
200
201
202
203
204
205
206
            // Return the total cache up to current position
            auto k_total = k_cache->narrow({{2, 0, cache_position}});
            auto v_total = v_cache->narrow({{2, 0, cache_position}});

            return std::make_pair(k_total, v_total);
        } else {
            throw std::runtime_error("KVCache update: cache positions must be equal among a batch.");
        }
207
208
209
    }
};

/**
 * @brief Model-level KV cache container (similar to DynamicCache in Python)
 *
 * Stores a list of KVCacheLayer objects, one per model layer.
 * This aligns with Python backend's DynamicCache architecture.
 */
class DynamicCache {
public:
218
219
220
221
222
223
224
225
226
227
228
    /**
     * @brief Construct DynamicCache with cache configuration
     * @param cache_config Cache configuration parameters
     */
    DynamicCache(const CacheConfig &cache_config)
        : cache_config_(cache_config), layers_(cache_config.num_layers) {
        if (cache_config.num_layers == -1) {
            throw std::runtime_error("DynamicCache: num_layers must be specified in CacheConfig");
        }
    }

Ceng's avatar
Ceng committed
229
230
231
232
233
234
235
    /**
     * @brief Construct DynamicCache with specified number of layers
     *
     * @param num_layers Number of model layers (creates one cache layer per model layer)
     * @param max_position_embeddings Maximum position embeddings (used for initial capacity)
     */
    DynamicCache(size_t num_layers, size_t max_position_embeddings = 4096)
236
        : cache_config_(CacheConfig(CacheType::DYNAMIC, num_layers, max_position_embeddings)), layers_(num_layers) {}
Ceng's avatar
Ceng committed
237
238
239
240
241
242
243
244
245
246

    /**
     * @brief Update cache with new key and value states for a specific layer
     */
    std::pair<infinicore::Tensor, infinicore::Tensor> update(
        size_t layer_idx,
        const infinicore::Tensor &k_new,
        const infinicore::Tensor &v_new) {
        if (layer_idx >= layers_.size()) {
            SPDLOG_ERROR("DynamicCache::update: layer_idx {} out of range (num_layers: {})",
247
                         layer_idx, layers_.size());
Ceng's avatar
Ceng committed
248
249
250
            throw std::runtime_error("DynamicCache: layer_idx out of range");
        }

251
252
        // Update the cache for this layer with cache config
        return layers_[layer_idx].update(k_new, v_new, cache_config_);
Ceng's avatar
Ceng committed
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
    }

    /**
     * @brief Update cache with new key and value states (convenience method without layer_idx)
     * This is used when the cache is accessed directly without layer information
     *
     * @param k_new New key states [batch_size, n_kv_head, seq_len, head_dim]
     * @param v_new New value states [batch_size, n_kv_head, seq_len, head_dim]
     * @return Tuple of (k_total, v_total) with shape [batch_size, n_kv_head, total_seq_len, head_dim]
     *
     * Note: This assumes layer_idx=0. For multi-layer models, use update(layer_idx, k_new, v_new) instead.
     */
    std::pair<infinicore::Tensor, infinicore::Tensor> update(
        const infinicore::Tensor &k_new,
        const infinicore::Tensor &v_new) {
        return update(0, k_new, v_new);
    }

271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
    /**
     * @brief Get cache configuration
     */
    const CacheConfig &get_config() const { return cache_config_; }

    /**
     * @brief Update cache configuration (for dynamic reconfiguration)
     */
    void update_config(const CacheConfig &new_config) {
        // Check if we need to rebuild
        bool need_rebuild = false;

        // Rebuild if number of layers changed
        if (new_config.num_layers != cache_config_.num_layers || new_config.initial_batch_size != cache_config_.initial_batch_size) {
            need_rebuild = true;
            layers_.resize(new_config.num_layers);
        }

        // Rebuild if reset mode is RECREATE
        if (new_config.reset_mode == CacheResetMode::RECREATE) {
            need_rebuild = true;
        }

        // Update configuration
        cache_config_ = new_config;

        if (need_rebuild) {
            // Clear all layers to force reinitialization on next use
            for (auto &layer : layers_) {
                layer.initialized = false;
                layer.max_capacity = 0;
                // Tensors will be recreated when ensure_capacity is called
            }
            spdlog::info("DynamicCache configuration updated - cache will be rebuilt on next use");
        } else {
            spdlog::info("DynamicCache configuration updated: layers={}, initial_capacity={}, growth_factor={}",
                         new_config.num_layers, new_config.initial_capacity, new_config.growth_factor);
        }
    }

Ceng's avatar
Ceng committed
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
    /**
     * @brief Get the number of layers in this cache
     */
    size_t num_layers() const { return layers_.size(); }

    /**
     * @brief Get cache position for a specific layer
     */
    size_t cache_position(size_t layer_idx) const {
        if (layer_idx >= layers_.size()) {
            throw std::runtime_error("DynamicCache: layer_idx out of range");
        }
        if (layers_[layer_idx].cache_positions.empty()) {
            return 0;
        }
        return layers_[layer_idx].cache_positions[0]; // All batch items should have same position
    }

    /**
     * @brief Get max position embeddings (used for initial capacity)
     */
332
    size_t max_kv_cache_length() const { return cache_config_.max_kv_cache_length; }
Ceng's avatar
Ceng committed
333
334
335
336
337
338
339

    /**
     * @brief Reset cache for all layers to a specific position
     * This should be called when starting a new generation sequence or resetting to a specific position
     * @param pos Position to reset to (defaults to 0)
     */
    void reset(size_t pos = 0) {
340
        for (auto &layer : layers_) {
Ceng's avatar
Ceng committed
341
342
343
344
345
346
347
348
349
            std::fill(layer.cache_positions.begin(), layer.cache_positions.end(), pos);
            // Note: We don't reset initialized flag or clear the cache tensors
            // to avoid reallocation. The cache will be overwritten on next update.
        }
    }

    /**
     * @brief Access a specific layer's cache (for advanced usage)
     */
350
    KVCacheLayer &layer(size_t layer_idx) {
Ceng's avatar
Ceng committed
351
352
353
354
355
356
        if (layer_idx >= layers_.size()) {
            throw std::runtime_error("DynamicCache: layer_idx out of range");
        }
        return layers_[layer_idx];
    }

357
    const KVCacheLayer &layer(size_t layer_idx) const {
Ceng's avatar
Ceng committed
358
359
360
361
362
363
364
        if (layer_idx >= layers_.size()) {
            throw std::runtime_error("DynamicCache: layer_idx out of range");
        }
        return layers_[layer_idx];
    }

private:
365
    CacheConfig cache_config_;
Ceng's avatar
Ceng committed
366
367
368
    std::vector<KVCacheLayer> layers_;
};

} // namespace infinilm::cache