#pragma once

#include "llama.h"
#include "llama-io.h"
#include "llama-memory.h"

#include "ggml-cpp.h"

#include <functional>
#include <set>
#include <vector>

struct llama_ubatch;
struct llama_hparams;
struct llama_cparams;

// Abstract KV-cache interface, derived from llama_memory_i.
// Concrete caches implement the transactional restore()/commit() pair and the
// queries below; get_can_edit() is answered in terms of get_can_shift().
struct llama_kv_cache : public llama_memory_i {
    // reuse the base-class constructors as-is
    using llama_memory_i::llama_memory_i;

    // roll the cache back to its pre-batch state; call when batch processing fails
    virtual void restore() = 0;
    // make the pending changes permanent; call after a batch succeeded (clears pending state)
    virtual void commit() = 0;

    virtual int32_t get_n_tokens() const = 0;

    // TODO: remove, this is too-specific to the unified cache
    virtual int32_t get_used_cells() const = 0;

    virtual bool get_can_shift() const = 0;

    // editing is reported as possible exactly when shifting is possible
    bool get_can_edit() const override {
        const bool shiftable = get_can_shift();
        return shiftable;
    }
};

// Scope guard for processing one batch against a KV cache.
// The destructor always calls kv->restore(), so pending cache changes are rolled
// back if the scope is left without committing. Call commit() after successful
// processing so the pending state is cleared before the final restore() runs.
// The guard does not own `kv`; a null pointer is tolerated and makes it a no-op.
struct llama_kv_cache_guard {
    explicit llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}

    // The destructor has a side effect (restore), so a copied or moved guard would
    // restore the same cache twice -- forbid duplication entirely.
    llama_kv_cache_guard(const llama_kv_cache_guard &)             = delete;
    llama_kv_cache_guard(llama_kv_cache_guard &&)                  = delete;
    llama_kv_cache_guard & operator=(const llama_kv_cache_guard &) = delete;
    llama_kv_cache_guard & operator=(llama_kv_cache_guard &&)      = delete;

    ~llama_kv_cache_guard() {
        if (kv) {
            kv->restore();
        }
    }

    // accept the pending changes of the guarded cache
    void commit() {
        if (kv) {
            kv->commit();
        }
    }

private:
    llama_kv_cache * kv; // non-owning
};

// block of KV slots to move when defragging
// Members default to 0 so a default-initialized move never carries indeterminate
// values; aggregate initialization such as {src, dst, len} still works (C++14+).
struct llama_kv_defrag_move {
    uint32_t src = 0; // first source cell index
    uint32_t dst = 0; // first destination cell index
    uint32_t len = 0; // number of consecutive cells to move
};

struct llama_kv_cell {
    llama_pos pos   = -1;
55
    llama_pos delta =  0;
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
    int32_t   src   = -1; // used by recurrent state models to copy states
    int32_t   tail  = -1;

    std::set<llama_seq_id> seq_id;

    bool has_seq_id(const llama_seq_id & id) const {
        return seq_id.find(id) != seq_id.end();
    }

    bool is_empty() const {
        return seq_id.empty();
    }

    bool is_same_seq(const llama_kv_cell & other) const {
        return seq_id == other.seq_id;
    }
};

// ring-buffer of cached KV data
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
// TODO: pimpl
// TODO: add notion of max sequences
class llama_kv_cache_unified : public llama_kv_cache {
public:
    // can be used to query data from the model if needed
    struct callbacks {
        std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
    };

    llama_kv_cache_unified(
            const llama_hparams & hparams,
            callbacks             cbs);

    virtual ~llama_kv_cache_unified() = default;

    // TODO: become constructor
    bool init(
            const llama_model & model,   // TODO: do not reference the model
          const llama_cparams & cparams,
                    ggml_type   type_k,
                    ggml_type   type_v,
                     uint32_t   kv_size,
                         bool   offload);
98

99
100
    int32_t get_n_tokens()   const override;
    int32_t get_used_cells() const override;
101

102
    size_t total_size() const;
103

104
105
    // TODO: better data structures to reduce the cost of this operation
    llama_pos pos_max() const;
106

107
108
    void clear() override;
    void defrag() override;
109

110
111
    virtual void restore() override;
    virtual void commit() override;
112

113
114
115
116
117
    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) override;
    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
118

119
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
120

121
    bool get_can_shift() const override;
122

123
124
125
126
127
    // find an empty slot of size "n_tokens" in the cache
    // updates the cache head
    // Note: On success, it's important that cache.head points
    // to the first cell of the slot.
    bool find_slot(const llama_ubatch & batch);
128

129
130
    // TODO: maybe not needed
    uint32_t get_padding(const llama_cparams & cparams) const;
131

132
133
    // find how many cells are currently in use
    uint32_t cell_max() const;
134

135
136
    size_t size_k_bytes() const;
    size_t size_v_bytes() const;
137

138
    // defrag
139

140
141
142
    struct {
        std::vector<llama_kv_defrag_move> moves;
    } defrag_info;
143

144
145
    // return true if cells have been moved
    bool defrag_prepare(int32_t n_max_nodes);
146

147
    // commit/restore cache
148

149
150
151
152
    struct slot_range {
        uint32_t c0 = 0; // note: these are cell indices, not sequence positions
        uint32_t c1 = 0;
    };
153

154
155
156
157
    // pending cell updates that are not yet committed
    struct {
        std::vector<slot_range> ranges;
    } pending;
158

159
    // state write/load
160

161
162
    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1);
163

164
    // members
165

166
    const llama_hparams & hparams;
167

168
    callbacks cbs;
169

170
171
    bool has_shift = false;
    bool do_defrag = false;
172

173
174
    // TODO: remove this and implement llama_kv_cache_recurrent instead
    bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
175

176
177
    bool v_trans   = true;  // the value tensor is transposed
    bool can_shift = false;
178

179
180
181
182
183
184
    // Note: The value of head isn't only used to optimize searching
    // for a free KV slot. llama_decode_impl also uses it, so it
    // cannot be freely changed after a slot has been allocated.
    uint32_t head = 0;
    uint32_t size = 0;
    uint32_t used = 0; // used cells (i.e. at least one seq_id)
185

186
187
    // computed before each graph build
    uint32_t n = 0;
188

189
    std::vector<llama_kv_cell> cells;
190

191
192
    std::vector<ggml_tensor *> k_l; // per layer
    std::vector<ggml_tensor *> v_l;
193

194
195
196
private:
    ggml_type type_k = GGML_TYPE_F16;
    ggml_type type_v = GGML_TYPE_F16;
197

198
199
    std::vector<ggml_context_ptr>        ctxs;
    std::vector<ggml_backend_buffer_ptr> bufs;
200

201
202
    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
203

204
205
206
    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};

// TODO: temporarily reusing llama_kv_cache_unified -- implement a recurrent cache and simplify llama_kv_cache_unified
//class llama_kv_cache_recurrent : public llama_kv_cache_unified {
//public:
//    using llama_kv_cache_unified::llama_kv_cache_unified;
//};

//
// kv cache view
//
217

218
llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max);
219

220
void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv);