llama-memory-hybrid.h 4.24 KB
Newer Older
1
2
3
4
#pragma once

#include "llama-batch.h"
#include "llama-graph.h"
Daniel Hiltgen's avatar
Daniel Hiltgen committed
5
#include "llama-kv-cache.h"
6
7
8
9
10
11
12
13
14
15
#include "llama-memory.h"
#include "llama-memory-recurrent.h"

#include <memory>
#include <vector>

//
// llama_memory_hybrid
//

Daniel Hiltgen's avatar
Daniel Hiltgen committed
16
// utilizes instances of llama_memory_recurrent and llama_kv_cache to
17
18
19
20
21
22
23
//   support models where each layer may be either attention-based or recurrent

class llama_memory_hybrid : public llama_memory_i {
public:
    llama_memory_hybrid(
        const llama_model & model,
                            /* attn */
Daniel Hiltgen's avatar
Daniel Hiltgen committed
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
                ggml_type   type_k,
                ggml_type   type_v,
                     bool   v_trans,
                 uint32_t   kv_size,
                 uint32_t   n_pad,
                 uint32_t   n_swa,
           llama_swa_type   swa_type,
                            /* recurrent */
                ggml_type   type_r,
                ggml_type   type_s,
                 uint32_t   rs_size,
                            /* common */
                 uint32_t   n_seq_max,
                     bool   offload,
                     bool   unified,
                            /* layer filters */
    const layer_filter_cb & filter_attn = nullptr,
    const layer_filter_cb & filter_recr = nullptr);
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70

    ~llama_memory_hybrid() = default;

    //
    // llama_memory_i
    //

    llama_memory_context_ptr init_batch(
            llama_batch_allocr & balloc,
            uint32_t n_ubatch,
            bool embd_all) override;

    llama_memory_context_ptr init_full() override;

    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;

    bool get_can_shift() const override;

    void clear(bool data) override;

    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id)                                                          override;
    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

Daniel Hiltgen's avatar
Daniel Hiltgen committed
71
72
    std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;

73
74
    // state write/load

Daniel Hiltgen's avatar
Daniel Hiltgen committed
75
76
    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0)       override;
77
78
79
80
81

    //
    // llama_memory_hybrid specific API
    //

Daniel Hiltgen's avatar
Daniel Hiltgen committed
82
    llama_kv_cache * get_mem_attn() const;
83
84
85
86
87
    llama_memory_recurrent * get_mem_recr() const;

private:
    const llama_hparams & hparams;

Daniel Hiltgen's avatar
Daniel Hiltgen committed
88
    const std::unique_ptr<llama_kv_cache> mem_attn;
89
90
91
92
93
    const std::unique_ptr<llama_memory_recurrent> mem_recr;
};

class llama_memory_hybrid_context : public llama_memory_context_i {
public:
Daniel Hiltgen's avatar
Daniel Hiltgen committed
94
    using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125

    // init failure
    explicit llama_memory_hybrid_context(llama_memory_status status);

    // init full
    explicit llama_memory_hybrid_context(llama_memory_hybrid * mem);

    // init update
    explicit llama_memory_hybrid_context(
        llama_memory_hybrid * mem,
              llama_context * lctx,
                       bool   optimize);

    // init success
    llama_memory_hybrid_context(
              llama_memory_hybrid * mem,
                  slot_info_vec_t   sinfos_attn,
        std::vector<llama_ubatch>   ubatches);

    ~llama_memory_hybrid_context() = default;

    bool next()  override;
    bool apply() override;

    llama_memory_status  get_status() const override;
    const llama_ubatch & get_ubatch() const override;

    //
    // llama_memory_hybrid_context
    //

Daniel Hiltgen's avatar
Daniel Hiltgen committed
126
    const llama_kv_cache_context * get_attn() const;
127
128
129
130
131
132
133
134
135
136
137
138
139
    const llama_memory_recurrent_context * get_recr() const;

private:
    // the index of the next ubatch to process
    size_t i_next = 0;

    std::vector<llama_ubatch> ubatches;

    const llama_memory_context_ptr ctx_attn;
    const llama_memory_context_ptr ctx_recr;

    const llama_memory_status status;
};