#pragma once

#include "base_cache.hpp"

#include "infinicore/context/context.hpp"
#include "infinicore/device.hpp"
#include "infinicore/tensor.hpp"

#include <algorithm>
#include <cstddef>
#include <limits>
#include <memory>
#include <numeric>
#include <stdexcept>
#include <tuple>
#include <utility>

#include <spdlog/spdlog.h>

namespace infinilm::cache {
class StaticKVCacheConfig final : public CacheConfig {
public:
    StaticKVCacheConfig(
        infinicore::Size _max_batch_size = 1,
        infinicore::Size _max_cache_len = std::numeric_limits<infinicore::Size>::max());
24

PanZezhong's avatar
PanZezhong committed
25
26
27
    std::unique_ptr<CacheConfig> unique_copy() const override;
    infinicore::Size max_batch_size() const;
    infinicore::Size max_cache_len() const;
28

PanZezhong's avatar
PanZezhong committed
29
30
31
private:
    infinicore::Size max_batch_size_;
    infinicore::Size max_cache_len_;
32
33
};

class StaticKVCache final : public Cache {
Ceng's avatar
Ceng committed
35
public:
PanZezhong's avatar
PanZezhong committed
36
37
38
39
40
41
42
43
44
45
46
    StaticKVCache(

        infinicore::Size k_dim,
        infinicore::Size v_dim,
        infinicore::Size num_k_heads,
        infinicore::Size num_v_heads,
        infinicore::Size num_layers,
        infinicore::Size max_positional_embedding,
        infinicore::DataType dtype,
        const StaticKVCacheConfig &config,
        const engine::distributed::RankInfo &rank_info);
47

Ceng's avatar
Ceng committed
48
    /**
PanZezhong's avatar
PanZezhong committed
49
     * @brief Update KV cache at a given layer and cache position.
Ceng's avatar
Ceng committed
50
     *
PanZezhong's avatar
PanZezhong committed
51
52
53
54
     * @param layer_idx Which transformer layer
     * @param k         [batch, num_rank_k_heads, seq_len, k_dim]
     * @param v         [batch, num_rank_v_heads, seq_len, v_dim]
     * @param cache_pos Sequence position to write
Ceng's avatar
Ceng committed
55
     *
PanZezhong's avatar
PanZezhong committed
56
57
58
     * @return (full_k, full_v)
     *         full_k: [batch, num_rank_k_heads, cache_pos + seq_len, k_dim]
     *         full_v: [batch, num_rank_v_heads, cache_pos + seq_len, v_dim]
Ceng's avatar
Ceng committed
59
     */
PanZezhong's avatar
PanZezhong committed
60
61
62
63
64
    std::tuple<infinicore::Tensor, infinicore::Tensor>
    update(size_t layer_idx,
           const infinicore::Tensor &k,
           const infinicore::Tensor &v,
           const infinicore::Tensor &cache_positions);
Ceng's avatar
Ceng committed
65

PanZezhong's avatar
PanZezhong committed
66
    ~StaticKVCache() override = default;
Ceng's avatar
Ceng committed
67
68

private:
PanZezhong's avatar
PanZezhong committed
69
70
71
72
73
74
75
76
77
78
79
80
81
82
    infinicore::Size k_dim_;
    infinicore::Size v_dim_;
    infinicore::Size num_rank_k_heads_;
    infinicore::Size num_rank_v_heads_;
    infinicore::Size rank_batch_size_;
    infinicore::Size cache_len_;
    infinicore::Size rank_num_layers_;
    infinicore::DataType dtype_;

    // [num_layers, max_batch, num_rank_k_heads, max_cache_len, k_dim]
    infinicore::Tensor k_caches_;

    // [num_layers, max_batch, num_rank_v_heads, max_cache_len, v_dim]
    infinicore::Tensor v_caches_;
Ceng's avatar
Ceng committed
83
84
};

} // namespace infinilm::cache