#include "llama-kv-cache-iswa.h"

#include "llama-impl.h"
#include "llama-batch.h"
#include "llama-model.h"

#include <algorithm>
#include <cassert>

//
Daniel Hiltgen's avatar
Daniel Hiltgen committed
11
// llama_kv_cache_iswa
12
13
//

Daniel Hiltgen's avatar
Daniel Hiltgen committed
14
llama_kv_cache_iswa::llama_kv_cache_iswa(
15
16
17
18
19
20
21
22
23
24
        const llama_model & model,
                ggml_type   type_k,
                ggml_type   type_v,
                     bool   v_trans,
                     bool   offload,
                     bool   swa_full,
                     bool   unified,
                 uint32_t   kv_size,
                 uint32_t   n_seq_max,
                 uint32_t   n_ubatch,
Daniel Hiltgen's avatar
Daniel Hiltgen committed
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
                 uint32_t   n_pad,
    const layer_filter_cb & filter,
    const  layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {

    // chain filters
    const layer_filter_cb filter_base = [&](int32_t il) {
        if (filter && !filter(il)) {
            return false;
        }

        return !model.hparams.is_swa(il);
    };

    const layer_filter_cb filter_swa  = [&](int32_t il) {
        if (filter && !filter(il)) {
            return false;
        }

        return  model.hparams.is_swa(il);
    };
45
46
47

    const uint32_t size_base = kv_size;

Daniel Hiltgen's avatar
Daniel Hiltgen committed
48
49
50
    // note: the SWA cache is always padded to 256 for performance
    //       https://github.com/ggml-org/llama.cpp/issues/17037
    uint32_t size_swa = GGML_PAD(std::min(size_base, hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch), 256);
51
52
53
54
55
56
57
58
59
60
61

    // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
    if (swa_full) {
        LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
                __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");

        size_swa = size_base;
    }

    LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);

Daniel Hiltgen's avatar
Daniel Hiltgen committed
62
63
    kv_base = std::make_unique<llama_kv_cache>(
            model, type_k, type_v,
64
            v_trans, offload, unified, size_base, n_seq_max, n_pad,
Daniel Hiltgen's avatar
Daniel Hiltgen committed
65
            0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
66
67
68

    LLAMA_LOG_INFO("%s: creating     SWA KV cache, size = %u cells\n", __func__, size_swa);

Daniel Hiltgen's avatar
Daniel Hiltgen committed
69
70
    kv_swa = std::make_unique<llama_kv_cache>(
            model, type_k, type_v,
71
            v_trans, offload, unified, size_swa, n_seq_max, n_pad,
Daniel Hiltgen's avatar
Daniel Hiltgen committed
72
            hparams.n_swa, hparams.swa_type, filter_swa, reuse);
73
74
}

Daniel Hiltgen's avatar
Daniel Hiltgen committed
75
void llama_kv_cache_iswa::clear(bool data) {
76
77
78
79
    kv_base->clear(data);
    kv_swa ->clear(data);
}

Daniel Hiltgen's avatar
Daniel Hiltgen committed
80
// Remove positions [p0, p1) of a sequence from both caches.
// Returns true only if the removal succeeded in both (non-short-circuiting
// `&` so both caches are always attempted).
bool llama_kv_cache_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    bool res = true;

    res = res & kv_base->seq_rm(seq_id, p0, p1);
    res = res & kv_swa ->seq_rm(seq_id, p0, p1);

    return res;
}
// Copy positions [p0, p1) from one sequence to another in both caches.
void llama_kv_cache_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
    kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
    kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}
// Keep only the given sequence in both caches (drop all others).
void llama_kv_cache_iswa::seq_keep(llama_seq_id seq_id) {
    kv_base->seq_keep(seq_id);
    kv_swa ->seq_keep(seq_id);
}
// Shift positions [p0, p1) of a sequence by `shift` in both caches.
void llama_kv_cache_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
    kv_base->seq_add(seq_id, p0, p1, shift);
    kv_swa ->seq_add(seq_id, p0, p1, shift);
}
// Divide positions [p0, p1) of a sequence by `d` in both caches.
void llama_kv_cache_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
    kv_base->seq_div(seq_id, p0, p1, d);
    kv_swa ->seq_div(seq_id, p0, p1, d);
}
llama_pos llama_kv_cache_iswa::seq_pos_min(llama_seq_id seq_id) const {
    // the base cache is a superset of the SWA cache, so we can just check the SWA cache
    return kv_swa->seq_pos_min(seq_id);
}
llama_pos llama_kv_cache_iswa::seq_pos_max(llama_seq_id seq_id) const {
    // as in seq_pos_min: the SWA cache is sufficient to answer this
    return kv_swa->seq_pos_max(seq_id);
}
// Per-backend-buffer-type memory usage: sum of the two child caches.
std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache_iswa::memory_breakdown() const {
    std::map<ggml_backend_buffer_type_t, size_t> mb = kv_base->memory_breakdown();
    for (const auto & buft_size : kv_swa->memory_breakdown()) {
        mb[buft_size.first] += buft_size.second;
    }
    return mb;
}
// Split the batch into ubatches and reserve slots in BOTH caches.
// Strategy: try a simple split first (unified caches only), then fall back to
// an equal split; on total failure return a FAILED_PREPARE context.
llama_memory_context_ptr llama_kv_cache_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
    GGML_UNUSED(embd_all);

    // first try simple split
    do {
        if (!unified) {
            // requires equal splits, so we skip the simple split
            break;
        }

        balloc.split_reset();

        std::vector<llama_ubatch> ubatches;
        while (true) {
            auto ubatch = balloc.split_simple(n_ubatch);

            if (ubatch.n_tokens == 0) {
                break;
            }

            ubatches.push_back(std::move(ubatch)); // NOLINT
        }

        if (balloc.get_n_used() < balloc.get_n_tokens()) {
            // failed to find a suitable split
            break;
        }

        // both caches must be able to host every ubatch
        auto sinfos_base = kv_base->prepare(ubatches);
        if (sinfos_base.empty()) {
            break;
        }

        auto sinfos_swa = kv_swa->prepare(ubatches);
        if (sinfos_swa.empty()) {
            break;
        }

        assert(sinfos_base.size() == sinfos_swa.size());

        return std::make_unique<llama_kv_cache_iswa_context>(
                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
    } while (false);

    // if it fails, try equal split
    do {
        balloc.split_reset();

        std::vector<llama_ubatch> ubatches;
        while (true) {
            auto ubatch = balloc.split_equal(n_ubatch, !unified);

            if (ubatch.n_tokens == 0) {
                break;
            }

            ubatches.push_back(std::move(ubatch)); // NOLINT
        }

        if (balloc.get_n_used() < balloc.get_n_tokens()) {
            // failed to find a suitable split
            break;
        }

        auto sinfos_base = kv_base->prepare(ubatches);
        if (sinfos_base.empty()) {
            break;
        }

        auto sinfos_swa = kv_swa->prepare(ubatches);
        if (sinfos_swa.empty()) {
            break;
        }

        assert(sinfos_base.size() == sinfos_swa.size());

        return std::make_unique<llama_kv_cache_iswa_context>(
                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
    } while (false);

    // TODO: if we fail again, we should attempt different splitting strategies
    //       but to do that properly, we first have to refactor the batches to be more flexible

    return std::make_unique<llama_kv_cache_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
// Create a context covering the full contents of both caches.
llama_memory_context_ptr llama_kv_cache_iswa::init_full() {
    return std::make_unique<llama_kv_cache_iswa_context>(this);
}
// Create a context that applies pending updates (e.g. shifts/defrag) to both caches.
llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, bool optimize) {
    return std::make_unique<llama_kv_cache_iswa_context>(this, lctx, optimize);
}
bool llama_kv_cache_iswa::get_can_shift() const {
221
222
223
    return kv_base->get_size() == kv_swa->get_size();
}

Daniel Hiltgen's avatar
Daniel Hiltgen committed
224
// Serialize cache state. With PARTIAL_ONLY set, only the SWA cache is written;
// otherwise the base cache is written first, then the SWA cache.
void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
    if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
        kv_base->state_write(io, seq_id, flags);
    }

    kv_swa->state_write(io, seq_id, flags);
}
// Deserialize cache state; mirrors the write order in state_write.
void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
    if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
        kv_base->state_read(io, seq_id, flags);
    }

    kv_swa->state_read(io, seq_id, flags);
}
// Non-owning pointer to the non-SWA child cache.
llama_kv_cache * llama_kv_cache_iswa::get_base() const {
    return kv_base.get();
}
// Non-owning pointer to the SWA child cache.
llama_kv_cache * llama_kv_cache_iswa::get_swa() const {
    return kv_swa.get();
}

//
Daniel Hiltgen's avatar
Daniel Hiltgen committed
249
// llama_kv_cache_iswa_context
250
251
//

Daniel Hiltgen's avatar
Daniel Hiltgen committed
252
llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(llama_memory_status status) : status(status) {}
253

Daniel Hiltgen's avatar
Daniel Hiltgen committed
254
255
llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
        llama_kv_cache_iswa * kv) :
256
257
258
259
260
    ctx_base(kv->get_base()->init_full()),
    ctx_swa (kv->get_swa ()->init_full()),
    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}

Daniel Hiltgen's avatar
Daniel Hiltgen committed
261
262
llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
        llama_kv_cache_iswa * kv,
263
264
265
266
267
268
269
        llama_context * lctx,
        bool optimize) :
    ctx_base(kv->get_base()->init_update(lctx, optimize)),
    ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}

Daniel Hiltgen's avatar
Daniel Hiltgen committed
270
271
llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
        llama_kv_cache_iswa * kv,
272
273
274
275
276
        slot_info_vec_t sinfos_base,
        slot_info_vec_t sinfos_swa,
        std::vector<llama_ubatch> ubatches) :
    ubatches(std::move(ubatches)),
    // note: here we copy the ubatches. not sure if this is ideal
Daniel Hiltgen's avatar
Daniel Hiltgen committed
277
278
    ctx_base(new llama_kv_cache_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
    ctx_swa (new llama_kv_cache_context(kv->get_swa (), std::move(sinfos_swa),  this->ubatches)),
279
280
281
    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}

Daniel Hiltgen's avatar
Daniel Hiltgen committed
282
llama_kv_cache_iswa_context:: ~llama_kv_cache_iswa_context() = default;
283

Daniel Hiltgen's avatar
Daniel Hiltgen committed
284
bool llama_kv_cache_iswa_context::next() {
285
286
287
288
289
290
291
292
293
294
295
296
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    ctx_base->next();
    ctx_swa ->next();

    if (++i_next >= ubatches.size()) {
        return false;
    }

    return true;
}

Daniel Hiltgen's avatar
Daniel Hiltgen committed
297
bool llama_kv_cache_iswa_context::apply() {
298
299
300
301
302
303
304
305
306
307
    assert(!llama_memory_status_is_fail(status));

    bool res = true;

    res = res & ctx_base->apply();
    res = res & ctx_swa ->apply();

    return res;
}

Daniel Hiltgen's avatar
Daniel Hiltgen committed
308
llama_memory_status llama_kv_cache_iswa_context::get_status() const {
309
310
311
    return status;
}

Daniel Hiltgen's avatar
Daniel Hiltgen committed
312
const llama_ubatch & llama_kv_cache_iswa_context::get_ubatch() const {
313
314
315
316
317
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return ubatches[i_next];
}

Daniel Hiltgen's avatar
Daniel Hiltgen committed
318
const llama_kv_cache_context * llama_kv_cache_iswa_context::get_base() const {
319
320
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

Daniel Hiltgen's avatar
Daniel Hiltgen committed
321
    return static_cast<const llama_kv_cache_context *>(ctx_base.get());
322
323
}

Daniel Hiltgen's avatar
Daniel Hiltgen committed
324
const llama_kv_cache_context * llama_kv_cache_iswa_context::get_swa()  const {
325
326
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

Daniel Hiltgen's avatar
Daniel Hiltgen committed
327
    return static_cast<const llama_kv_cache_context *>(ctx_swa.get());
328
}