#include "llama-kv-cache-iswa.h"

#include "llama-impl.h"
#include "llama-batch.h"
#include "llama-model.h"

#include <algorithm>
#include <cassert>

//
// llama_kv_cache_iswa
//
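// note: this cache wraps two child caches and routes layers between them:
//   - kv_base: full-context cache for the non-SWA layers (filter: !hparams.is_swa(il))
//   - kv_swa : sliding-window cache for the SWA layers   (filter:  hparams.is_swa(il))
// the memory operations below simply fan out to both caches and combine the results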

llama_kv_cache_iswa::llama_kv_cache_iswa(
        const llama_model & model,
                ggml_type   type_k,
                ggml_type   type_v,
                     bool   v_trans,
                     bool   offload,
                     bool   swa_full,
                     bool   unified,
                 uint32_t   kv_size,
                 uint32_t   n_seq_max,
                 uint32_t   n_ubatch,
                 uint32_t   n_pad,
    const layer_filter_cb & filter,
    const  layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {

    // chain filters
    const layer_filter_cb filter_base = [&](int32_t il) {
        if (filter && !filter(il)) {
            return false;
        }

        return !model.hparams.is_swa(il);
    };

    const layer_filter_cb filter_swa  = [&](int32_t il) {
        if (filter && !filter(il)) {
            return false;
        }

        return  model.hparams.is_swa(il);
    };

    const uint32_t size_base = kv_size;

    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
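    // the SWA cache only has to retain the last n_swa positions (per sequence when unified)
    // plus one ubatch of headroom, rounded up to a multiple of n_pad
    // e.g. (illustrative values) n_swa = 4096, unified, n_seq_max = 1, n_ubatch = 512, n_pad = 256:
    //      GGML_PAD(4096*1 + 512, 256) = 4608 cells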

    // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
    if (swa_full) {
        LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
                __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");

        size_swa = size_base;
    }

    LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);

    kv_base = std::make_unique<llama_kv_cache>(
            model, type_k, type_v,
            v_trans, offload, unified, size_base, n_seq_max, n_pad,
            0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);

    LLAMA_LOG_INFO("%s: creating     SWA KV cache, size = %u cells\n", __func__, size_swa);

    kv_swa = std::make_unique<llama_kv_cache>(
            model, type_k, type_v,
            v_trans, offload, unified, size_swa, n_seq_max, n_pad,
            hparams.n_swa, hparams.swa_type, filter_swa, reuse);
}

void llama_kv_cache_iswa::clear(bool data) {
    kv_base->clear(data);
    kv_swa ->clear(data);
}

bool llama_kv_cache_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    bool res = true;

    // note: & (not &&) so that the removal is attempted in both caches even if one fails
    res = res & kv_base->seq_rm(seq_id, p0, p1);
    res = res & kv_swa ->seq_rm(seq_id, p0, p1);

    return res;
}

void llama_kv_cache_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
    kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
    kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}

void llama_kv_cache_iswa::seq_keep(llama_seq_id seq_id) {
    kv_base->seq_keep(seq_id);
    kv_swa ->seq_keep(seq_id);
}

void llama_kv_cache_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
    kv_base->seq_add(seq_id, p0, p1, shift);
    kv_swa ->seq_add(seq_id, p0, p1, shift);
}

void llama_kv_cache_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
    kv_base->seq_div(seq_id, p0, p1, d);
    kv_swa ->seq_div(seq_id, p0, p1, d);
}

llama_pos llama_kv_cache_iswa::seq_pos_min(llama_seq_id seq_id) const {
    // the base cache is a superset of the SWA cache, so we can just check the SWA cache
    return kv_swa->seq_pos_min(seq_id);
}

llama_pos llama_kv_cache_iswa::seq_pos_max(llama_seq_id seq_id) const {
    return kv_swa->seq_pos_max(seq_id);
}

std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache_iswa::memory_breakdown() const {
    std::map<ggml_backend_buffer_type_t, size_t> mb = kv_base->memory_breakdown();
    for (const auto & buft_size : kv_swa->memory_breakdown()) {
        mb[buft_size.first] += buft_size.second;
    }
    return mb;
}

llama_memory_context_ptr llama_kv_cache_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
    GGML_UNUSED(embd_all);
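    // the same sequence of ubatches has to be placed in both caches, so each splitting
    // strategy below is accepted only if slot preparation succeeds in kv_base and kv_swa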

    // first try simple split
    do {
        if (!unified) {
            // requires equal splits, so we skip the simple split
            break;
        }

        balloc.split_reset();

        std::vector<llama_ubatch> ubatches;
        while (true) {
            auto ubatch = balloc.split_simple(n_ubatch);

            if (ubatch.n_tokens == 0) {
                break;
            }

            ubatches.push_back(std::move(ubatch)); // NOLINT
        }

        if (balloc.get_n_used() < balloc.get_n_tokens()) {
            // failed to find a suitable split
            break;
        }

        auto sinfos_base = kv_base->prepare(ubatches);
        if (sinfos_base.empty()) {
            break;
        }

        auto sinfos_swa = kv_swa->prepare(ubatches);
        if (sinfos_swa.empty()) {
            break;
        }

        assert(sinfos_base.size() == sinfos_swa.size());

        return std::make_unique<llama_kv_cache_iswa_context>(
                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
    } while (false);

    // if it fails, try equal split
    do {
        balloc.split_reset();

        std::vector<llama_ubatch> ubatches;
        while (true) {
            auto ubatch = balloc.split_equal(n_ubatch, !unified);

            if (ubatch.n_tokens == 0) {
                break;
            }

            ubatches.push_back(std::move(ubatch)); // NOLINT
        }

        if (balloc.get_n_used() < balloc.get_n_tokens()) {
            // failed to find a suitable split
            break;
        }

        auto sinfos_base = kv_base->prepare(ubatches);
        if (sinfos_base.empty()) {
            break;
        }

        auto sinfos_swa = kv_swa->prepare(ubatches);
        if (sinfos_swa.empty()) {
            break;
        }

        assert(sinfos_base.size() == sinfos_swa.size());

        return std::make_unique<llama_kv_cache_iswa_context>(
                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
    } while (false);

    // TODO: if we fail again, we should attempt different splitting strategies
    //       but to do that properly, we first have to refactor the batches to be more flexible

    return std::make_unique<llama_kv_cache_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}

llama_memory_context_ptr llama_kv_cache_iswa::init_full() {
    return std::make_unique<llama_kv_cache_iswa_context>(this);
}

llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, bool optimize) {
    return std::make_unique<llama_kv_cache_iswa_context>(this, lctx, optimize);
}

bool llama_kv_cache_iswa::get_can_shift() const {
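    // shifting is only supported when both caches have the same size
    // (presumably so that a position shift can be applied consistently to both)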
    return kv_base->get_size() == kv_swa->get_size();
}

void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
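    // the base cache is serialized only for a full state dump; the SWA cache is always written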
    if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
        kv_base->state_write(io, seq_id, flags);
    }

    kv_swa->state_write(io, seq_id, flags);
}

void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
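    // mirrors state_write: the base cache is restored only when a full state dump is provided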
    if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
        kv_base->state_read(io, seq_id, flags);
    }

    kv_swa->state_read(io, seq_id, flags);
}

llama_kv_cache * llama_kv_cache_iswa::get_base() const {
    return kv_base.get();
}

llama_kv_cache * llama_kv_cache_iswa::get_swa() const {
    return kv_swa.get();
}

//
// llama_kv_cache_iswa_context
//
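// a minimal usage sketch (assumption: this is roughly how the driver in llama_context
// consumes the context; only member functions defined in this file are used):
//
//   auto mctx = kv->init_batch(balloc, n_ubatch, /*embd_all=*/false);
//   if (mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
//       do {
//           mctx->apply();                                // place the current ubatch in both caches
//           const llama_ubatch & ub = mctx->get_ubatch(); // ubatch to evaluate against the caches
//           // ... build and evaluate the graph for ub ...
//       } while (mctx->next());                           // advance to the next ubatch
//   }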

llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(llama_memory_status status) : status(status) {}

llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
        llama_kv_cache_iswa * kv) :
    ctx_base(kv->get_base()->init_full()),
    ctx_swa (kv->get_swa ()->init_full()),
    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}

llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
        llama_kv_cache_iswa * kv,
        llama_context * lctx,
        bool optimize) :
    ctx_base(kv->get_base()->init_update(lctx, optimize)),
    ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}

llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
        llama_kv_cache_iswa * kv,
        slot_info_vec_t sinfos_base,
        slot_info_vec_t sinfos_swa,
        std::vector<llama_ubatch> ubatches) :
    ubatches(std::move(ubatches)),
    // note: here we copy the ubatches. not sure if this is ideal
    ctx_base(new llama_kv_cache_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
    ctx_swa (new llama_kv_cache_context(kv->get_swa (), std::move(sinfos_swa),  this->ubatches)),
    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}

llama_kv_cache_iswa_context::~llama_kv_cache_iswa_context() = default;

bool llama_kv_cache_iswa_context::next() {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    ctx_base->next();
    ctx_swa ->next();

    if (++i_next >= ubatches.size()) {
        return false;
    }

    return true;
}

bool llama_kv_cache_iswa_context::apply() {
    assert(!llama_memory_status_is_fail(status));

    bool res = true;

    res = res & ctx_base->apply();
    res = res & ctx_swa ->apply();

    return res;
}

llama_memory_status llama_kv_cache_iswa_context::get_status() const {
    return status;
}

const llama_ubatch & llama_kv_cache_iswa_context::get_ubatch() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return ubatches[i_next];
}

const llama_kv_cache_context * llama_kv_cache_iswa_context::get_base() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return static_cast<const llama_kv_cache_context *>(ctx_base.get());
}

const llama_kv_cache_context * llama_kv_cache_iswa_context::get_swa()  const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return static_cast<const llama_kv_cache_context *>(ctx_swa.get());
}