#pragma once

#include "base_cache.hpp"

#include "infinicore/context/context.hpp"
#include "infinicore/device.hpp"
#include "infinicore/tensor.hpp"

#include <algorithm>
#include <cstddef>
#include <limits>
#include <memory>
#include <numeric>
#include <stdexcept>
#include <tuple>
#include <utility>

#include <spdlog/spdlog.h>

namespace infinilm::cache {
PanZezhong's avatar
PanZezhong committed
19
20
21
22
23
class StaticKVCacheConfig final : public CacheConfig {
public:
    StaticKVCacheConfig(
        infinicore::Size _max_batch_size = 1,
        infinicore::Size _max_cache_len = std::numeric_limits<infinicore::Size>::max());
24

PanZezhong's avatar
PanZezhong committed
25
26
27
    std::unique_ptr<CacheConfig> unique_copy() const override;
    infinicore::Size max_batch_size() const;
    infinicore::Size max_cache_len() const;
28

PanZezhong's avatar
PanZezhong committed
29
30
31
private:
    infinicore::Size max_batch_size_;
    infinicore::Size max_cache_len_;
32
33
};

PanZezhong's avatar
PanZezhong committed
34
class StaticKVCache final : public Cache {
Ceng's avatar
Ceng committed
35
public:
PanZezhong's avatar
PanZezhong committed
36
37
38
39
40
41
42
43
44
45
46
    StaticKVCache(

        infinicore::Size k_dim,
        infinicore::Size v_dim,
        infinicore::Size num_k_heads,
        infinicore::Size num_v_heads,
        infinicore::Size num_layers,
        infinicore::Size max_positional_embedding,
        infinicore::DataType dtype,
        const StaticKVCacheConfig &config,
        const engine::distributed::RankInfo &rank_info);
47

Ceng's avatar
Ceng committed
48
    /**
PanZezhong's avatar
PanZezhong committed
49
     * @brief Update KV cache at a given layer and cache position.
Ceng's avatar
Ceng committed
50
     *
PanZezhong's avatar
PanZezhong committed
51
52
53
54
     * @param layer_idx Which transformer layer
     * @param k         [batch, num_rank_k_heads, seq_len, k_dim]
     * @param v         [batch, num_rank_v_heads, seq_len, v_dim]
     * @param cache_pos Sequence position to write
Ceng's avatar
Ceng committed
55
     *
PanZezhong's avatar
PanZezhong committed
56
57
58
     * @return (full_k, full_v)
     *         full_k: [batch, num_rank_k_heads, cache_pos + seq_len, k_dim]
     *         full_v: [batch, num_rank_v_heads, cache_pos + seq_len, v_dim]
Ceng's avatar
Ceng committed
59
     */
PanZezhong's avatar
PanZezhong committed
60
61
62
63
    std::tuple<infinicore::Tensor, infinicore::Tensor>
    update(size_t layer_idx,
           const infinicore::Tensor &k,
           const infinicore::Tensor &v,
64
           const infinicore::Tensor &past_sequence_lengths);
Ceng's avatar
Ceng committed
65

PanZezhong's avatar
PanZezhong committed
66
    ~StaticKVCache() override = default;
Ceng's avatar
Ceng committed
67
68

private:
PanZezhong's avatar
PanZezhong committed
69
70
71
72
73
74
75
76
77
78
79
80
81
82
    infinicore::Size k_dim_;
    infinicore::Size v_dim_;
    infinicore::Size num_rank_k_heads_;
    infinicore::Size num_rank_v_heads_;
    infinicore::Size rank_batch_size_;
    infinicore::Size cache_len_;
    infinicore::Size rank_num_layers_;
    infinicore::DataType dtype_;

    // [num_layers, max_batch, num_rank_k_heads, max_cache_len, k_dim]
    infinicore::Tensor k_caches_;

    // [num_layers, max_batch, num_rank_v_heads, max_cache_len, v_dim]
    infinicore::Tensor v_caches_;
Ceng's avatar
Ceng committed
83
84
};

85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
class PagedKVCacheConfig final : public CacheConfig {
public:
    PagedKVCacheConfig(
        size_t max_kv_memory_bytes,
        size_t block_size = 16);

    std::unique_ptr<CacheConfig> unique_copy() const override;
    size_t max_kv_memory_bytes() const;
    size_t block_size() const;

private:
    size_t max_kv_memory_bytes_;
    size_t block_size_;
};

class PagedKVCache final : public Cache {
public:
    PagedKVCache(

        infinicore::Size k_dim,
        infinicore::Size v_dim,
        infinicore::Size num_k_heads,
        infinicore::Size num_v_heads,
        infinicore::Size num_layers,
        infinicore::DataType dtype,
        const PagedKVCacheConfig &config,
        const engine::distributed::RankInfo &rank_info);

    /**
     * @brief Update Paged KV cache at a given layer given slot info for each token.
     *
116
     * @param layer_idx Which paged attention layer
117
118
119
120
121
122
123
124
125
126
127
128
129
130
     * @param k         [num_rank_k_heads, seq_len, k_dim]
     * @param v         [num_rank_v_heads, seq_len, v_dim]
     * @param slot_mapping [seq_len]
     *
     * @return (full_k, full_v)
     *         full_k: [num_blocks, num_rank_k_heads, block_size, k_dim]
     *         full_v: [num_blocks, num_rank_v_heads, block_size, v_dim]
     */
    std::tuple<infinicore::Tensor, infinicore::Tensor>
    update(size_t layer_idx,
           const infinicore::Tensor &k,
           const infinicore::Tensor &v,
           const infinicore::Tensor &slot_mapping);

131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
    /**
     * @brief Get Paged KV cache at a given layer.
     *
     * @param layer_idx Which paged attention layer
     *
     * @return (full_k, full_v)
     *         full_k: [num_blocks, num_rank_k_heads, block_size, k_dim]
     *         full_v: [num_blocks, num_rank_v_heads, block_size, v_dim]
     */
    std::tuple<infinicore::Tensor, infinicore::Tensor>
    get_paged_kv(size_t layer_idx);

    /**
     * @brief Get contiguous KV cache at a given layer, given the request info
     * among a continuous request batch.
     *
     * @param layer_idx Which paged attention layer
     * @param block_tables [num_requests, max_blocks_per_request]
     * @param cache_lens [num_requests]
     * @param input_offsets [num_requests + 1]
     * @param request_id Which request among a continuous batch of requests
     *
     * @return (full_k, full_v)
     *         full_k: [num_rank_k_heads, total_len, k_dim]
     *         full_v: [num_rank_v_heads, total_len, v_dim]
     */
    std::tuple<infinicore::Tensor, infinicore::Tensor>
    get_contiguous_kv(size_t layer_idx,
                      const infinicore::Tensor block_tables,
                      const infinicore::Tensor cache_lens,
                      const infinicore::Tensor input_offsets,
                      size_t request_id = 0);

    ~PagedKVCache() override
        = default;
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182

private:
    infinicore::Size k_dim_;
    infinicore::Size v_dim_;
    infinicore::Size num_rank_k_heads_;
    infinicore::Size num_rank_v_heads_;
    infinicore::Size rank_num_layers_;
    infinicore::DataType dtype_;
    infinicore::Size block_size_;
    infinicore::Size num_blocks_per_layer_;
    // [num_layers, num_blocks, num_rank_k_heads, block_size, k_dim]
    infinicore::Tensor k_caches_;

    // [num_layers, num_blocks, num_rank_v_heads, block_size, v_dim]
    infinicore::Tensor v_caches_;
};

} // namespace infinilm::cache