#pragma once

#include "llama.h"
#include "llama-io.h"
#include "llama-graph.h"
#include "llama-memory.h"

#include "ggml-cpp.h"

#include <set>
#include <vector>

struct llama_cparams;
struct llama_hparams;
struct llama_ubatch;
struct llama_sbatch;
struct llama_model;
struct llama_context;

struct llama_kv_cache : public llama_memory_i {
    virtual ~llama_kv_cache() = default;

    // call if batch processing fails - restores the cache state
    virtual void restore() = 0;

    // call after successful batch processing - clears any pending state
    virtual void commit()  = 0;

    // process any pending defrag/shift/etc. operations
    // optionally call once before processing a new batch
    virtual bool update(llama_context & lctx) = 0;

    // schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
    virtual void defrag_sched(float thold) = 0;
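
    // illustrative sketch (the threshold value here is just an example):
    // schedule a defrag when more than 10% of the cells are fragmented and let
    // the next update() carry it out:
    //
    //     kv->defrag_sched(0.1f);
    //     kv->update(lctx); // applies any pending shift/defrag operations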

    // simulate full cache, used for allocating worst-case compute buffers
    virtual void set_full() = 0;

    //
    // batch processing
    //

    virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;

    // different KV caches require different batch splitting strategies
    virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0;

    // find an empty slot of size "n_tokens" in the cache
    virtual bool find_slot(const llama_ubatch & batch) = 0;
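
    // illustrative sketch of the intended call sequence (hypothetical helper,
    // not part of the API); assumes llama_sbatch exposes n_tokens:
    //
    //     llama_sbatch sbatch = kv->sbatch_init(batch, /*logits_all=*/false);
    //     while (sbatch.n_tokens > 0) {
    //         llama_ubatch ubatch = kv->ubatch_next(sbatch, n_ubatch, /*embd_pooled=*/false);
    //         if (!kv->find_slot(ubatch)) {
    //             kv->restore(); // allocation failed - roll back any pending changes
    //             return false;
    //         }
    //         // ... build and compute the graph for this ubatch ...
    //     }
    //     kv->commit(); // all ubatches stored - make the pending cells permanent
    //     return true;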

    // getters
    virtual int32_t   get_n_tokens()   const = 0;
    virtual int32_t   get_used_cells() const = 0; // TODO: remove, this is too specific to the unified cache
    virtual llama_pos get_pos_max()    const = 0;
    virtual bool      get_can_shift()  const = 0;

    bool get_can_edit() const override { return get_can_shift(); }

    //
    // state write/read
    //

    virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
    virtual void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) = 0;
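
    // illustrative sketch: the default seq_id = -1 covers the whole cache,
    // otherwise only the given sequence is serialized:
    //
    //     kv->state_write(io);               // write all sequences
    //     kv->state_write(io, /*seq_id=*/0); // write only sequence 0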
};

//
// llama_kv_cache_guard
//

struct llama_kv_cache_guard {
    llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}

    ~llama_kv_cache_guard() {
        kv->restore();
    }

    void commit() {
        kv->commit();
    }

private:
    llama_kv_cache * kv;
};
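
// note: the guard is an RAII-style safety net - its destructor always calls
// restore(), which becomes a no-op once commit() has cleared the pending state.
// illustrative usage sketch (decode_one is a hypothetical helper):
//
//     int32_t decode_one(llama_kv_cache * kv, const llama_ubatch & ubatch) {
//         llama_kv_cache_guard guard(kv);
//
//         if (!kv->find_slot(ubatch)) {
//             return -1; // early return - the guard rolls the cache back
//         }
//
//         // ... compute the graph using the slot ...
//
//         guard.commit(); // success - keep the newly occupied cells
//         return 0;
//     }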

// block of KV slots to move when defragging
struct llama_kv_defrag_move {
    uint32_t src;
    uint32_t dst;
    uint32_t len;
};
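
// e.g. (illustrative) the move { /*src=*/32, /*dst=*/0, /*len=*/16 } relocates
// cells [32, 48) to [0, 16) in every layer's K and V tensors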

//
// llama_kv_cache_unified
//

// TODO: add notion of max sequences
class llama_kv_cache_unified : public llama_kv_cache {
public:
    struct kv_cell {
        llama_pos pos   = -1;
        llama_pos delta =  0;

        std::set<llama_seq_id> seq_id;

        bool has_seq_id(const llama_seq_id & id) const {
            return seq_id.find(id) != seq_id.end();
        }

        bool is_empty() const {
            return seq_id.empty();
        }

        bool is_same_seq(const kv_cell & other) const {
            return seq_id == other.seq_id;
        }
    };
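
    // e.g. (illustrative) a cell holding the token at position 5 that is shared
    // by sequences 0 and 2, as produced by seq_cp():
    //
    //     kv_cell cell;
    //     cell.pos    = 5;
    //     cell.seq_id = {0, 2};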

    static uint32_t get_padding(const llama_cparams & cparams);
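
    // e.g. (illustrative) callers are expected to round the cache size up to a
    // multiple of this value, along the lines of:
    //
    //     kv_size = GGML_PAD(n_ctx, llama_kv_cache_unified::get_padding(cparams));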

    llama_kv_cache_unified(
            const llama_model & model,
                    ggml_type   type_k,
                    ggml_type   type_v,
                         bool   v_trans,
                         bool   offload,
                     uint32_t   kv_size,
                     uint32_t   padding);

    ~llama_kv_cache_unified() = default;

    //
    // llama_memory_i
    //

    void clear() override;

    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) override;
    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    //
    // llama_kv_cache
    //

    void restore() override;
    void commit()  override;

    bool update(llama_context & ctx) override;

    void defrag_sched(float thold) override;

    void set_full() override;

    llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;

    llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;

    // updates the cache head
    // Note: On success, it's important that cache.head points
    // to the first cell of the slot.
    bool find_slot(const llama_ubatch & batch) override;

    int32_t get_n_tokens()   const override;
    int32_t get_used_cells() const override;

    // TODO: better data structures to reduce the cost of this operation
    llama_pos get_pos_max() const override;

    bool get_can_shift() const override;

    // state write/read

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;

    // Note: The value of head isn't only used to optimize searching
    // for a free KV slot. llama_decode_impl also uses it, so it
    // cannot be freely changed after a slot has been allocated.
    uint32_t head = 0;
    uint32_t size = 0;
    uint32_t used = 0; // used cells (i.e. at least one seq_id)

    // computed before each graph build
    uint32_t n = 0;

    std::vector<kv_cell> cells;

    std::vector<ggml_tensor *> k_l; // per layer
    std::vector<ggml_tensor *> v_l;

private:
    const llama_model & model;
    const llama_hparams & hparams;

    bool has_shift = false;
    bool do_defrag = false;

    bool v_trans   = true;  // the value tensor is transposed
    bool can_shift = false;

    // required padding
    uint32_t padding = 1;

    ggml_type type_k = GGML_TYPE_F16;
    ggml_type type_v = GGML_TYPE_F16;

    std::vector<ggml_context_ptr>        ctxs;
    std::vector<ggml_backend_buffer_ptr> bufs;

    // defrag
    struct {
        std::vector<llama_kv_defrag_move> moves;
    } defrag_info;

    // return true if cells have been moved
    bool defrag_prepare(int32_t n_max_nodes);

    // commit/restore cache
    struct slot_range {
        uint32_t c0 = 0; // note: these are cell indices, not sequence positions
        uint32_t c1 = 0;
    };

    // pending cell updates that are not yet committed
    struct {
        std::vector<slot_range> ranges;
    } pending;
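
    // illustrative flow (implementation detail, inferred from the interface):
    // find_slot() records the cell range it claims in pending.ranges; commit()
    // simply clears the list, while restore() first clears the cells in every
    // recorded range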

    // find how many cells are currently in use
    uint32_t cell_max() const;

    size_t total_size() const;

    size_t size_k_bytes() const;
    size_t size_v_bytes() const;

    ggml_tensor * build_rope_shift(
            const llama_cparams & cparams,
                   ggml_context * ctx,
                    ggml_tensor * cur,
                    ggml_tensor * shift,
                    ggml_tensor * factors,
                          float   freq_base,
                          float   freq_scale) const;

    llm_graph_result_ptr build_graph_shift(
            const llama_cparams & cparams,
                   ggml_context * ctx,
                    ggml_cgraph * gf) const;

    llm_graph_result_ptr build_graph_defrag(
            const llama_cparams & cparams,
                   ggml_context * ctx,
                    ggml_cgraph * gf,
                    const std::vector<llama_kv_defrag_move> & moves) const;
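
    // note (illustrative): the shift graph re-applies RoPE to the cached K data
    // using each cell's accumulated delta, which is what lets seq_add()/seq_div()
    // return immediately and defer the actual work to update()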

    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;

    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};

//
// llama_kv_cache_recurrent
//

class llama_kv_cache_recurrent : public llama_kv_cache {
public:
    struct kv_cell {
        llama_pos pos  = -1;
        int32_t   src  = -1; // used to copy states
        int32_t   tail = -1;

        std::set<llama_seq_id> seq_id;

        bool has_seq_id(const llama_seq_id & id) const {
            return seq_id.find(id) != seq_id.end();
        }

        bool is_empty() const {
            return seq_id.empty();
        }

        bool is_same_seq(const kv_cell & other) const {
            return seq_id == other.seq_id;
        }
    };

    llama_kv_cache_recurrent(
            const llama_model & model,
                    ggml_type   type_k,
                    ggml_type   type_v,
                         bool   offload,
                     uint32_t   kv_size);

    ~llama_kv_cache_recurrent() = default;

    //
    // llama_memory_i
    //

    void clear() override;

    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) override;
    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    //
    // llama_kv_cache
    //

    void restore() override;
    void commit()  override;

    bool update(llama_context & lctx) override;

    void defrag_sched(float thold) override;

    void set_full() override;

    llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;

    llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;

    bool find_slot(const llama_ubatch & batch) override;

    int32_t get_n_tokens()   const override;
    int32_t get_used_cells() const override;

    // TODO: better data structures to reduce the cost of this operation
    llama_pos get_pos_max() const override;

    bool get_can_shift() const override;

    // TODO: temporary methods - they are not really const as they do const_cast<>, fix this
    int32_t s_copy(int i) const;
    float   s_mask(int i) const;

    // state write/read

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;

    // Note: The value of head isn't only used to optimize searching
    // for a free KV slot. llama_decode_impl also uses it, so it
    // cannot be freely changed after a slot has been allocated.
    uint32_t head = 0;
    uint32_t size = 0;
    uint32_t used = 0; // used cells (i.e. at least one seq_id)

    // computed before each graph build
    uint32_t n = 0;

    std::vector<kv_cell> cells;

    std::vector<ggml_tensor *> k_l; // per layer
    std::vector<ggml_tensor *> v_l;

private:
    //const llama_model & model;
    const llama_hparams & hparams;

    // commit/restore cache
    // TODO: rework for recurrent cache
    struct slot_range {
        uint32_t c0 = 0; // note: these are cell indices, not sequence positions
        uint32_t c1 = 0;
    };

    // pending cell updates that are not yet committed
    struct {
        std::vector<slot_range> ranges;
    } pending;

    ggml_type type_k = GGML_TYPE_F16;
    ggml_type type_v = GGML_TYPE_F16;

    std::vector<ggml_context_ptr>        ctxs;
    std::vector<ggml_backend_buffer_ptr> bufs;

    // find how many cells are currently in use
    uint32_t cell_max() const;

    size_t total_size() const;

    size_t size_k_bytes() const;
    size_t size_v_bytes() const;

    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;

    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};

//
// kv cache view
//

llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max);

void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv);
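
// illustrative usage sketch (assumes the llama_kv_cache_view API from llama.h,
// including llama_kv_cache_view_free):
//
//     llama_kv_cache_view view = llama_kv_cache_view_init(*kv, /*n_seq_max=*/4);
//     llama_kv_cache_view_update(&view, kv);
//     // ... inspect view.n_cells, view.used_cells, etc. ...
//     llama_kv_cache_view_free(&view);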