0010-ensure-KV-cache-is-fully-defragmented.patch 12.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: jmorganca <jmorganca@gmail.com>
Date: Tue, 15 Apr 2025 14:27:40 -0400
Subject: [PATCH] ensure KV cache is fully defragmented

Sometimes the KV cache requires defragmentation even without
triggering the threshold heuristic. In this case, decoding
will not being able to find a KV cache slot. This is particularly
difficult for the caller to handle if it happens in between
ubatches. To avoid this, we should immediately trigger a defrag.

In addition, a heavily fragmented cache can require more than
max_moves to defragment. Currently, we stop when we hit the limit
but this can leave a cache that still does not have adequate space
even after defragmentation is triggered. Instead, we should do
multiple batches of processing until everything is complete.
---
18
 src/llama-context.cpp  |  18 ++++---
19
20
21
 src/llama-context.h    |   1 +
 src/llama-kv-cache.cpp | 107 ++++++++++++++---------------------------
 src/llama-kv-cache.h   |  12 ++++-
22
 4 files changed, 59 insertions(+), 79 deletions(-)
23

24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index c22687e4..c5948e8f 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -950,9 +950,12 @@ int llama_context::decode(llama_batch & inp_batch) {
 
         // find KV slot
         if (!kv_self->find_slot(ubatch)) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
-            return 1;
+            kv_self->defrag_sched(-1.0f);
+            kv_self->update(*this);
+            if (!kv_self->find_slot(ubatch)) {
+                LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
+                return 1;
+            }
         }
 
         ggml_backend_sched_reset(sched.get());
@@ -1967,9 +1970,12 @@ void llama_context::opt_epoch_iter(
 
             // TODO: not sure if this is needed
             if (!kv_self->find_slot(ubatch)) {
-                LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
-                GGML_ABORT("TODO: handle this error");
+                kv_self->defrag_sched(-1.0f);
+                kv_self->update(*this);
+                if (!kv_self->find_slot(ubatch)) {
+                    LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
+                    GGML_ABORT("TODO: handle this error");
+                }
             }
 
             auto * gf = graph_init();
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
diff --git a/src/llama-context.h b/src/llama-context.h
index c4ab242a..9970dfc6 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -5,6 +5,7 @@
 #include "llama-cparams.h"
 #include "llama-graph.h"
 #include "llama-adapter.h"
+#include "llama-kv-cache.h"
 
 #include "ggml-cpp.h"
 #include "ggml-opt.h"
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index a7b0a7eb..1a50c034 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -372,8 +372,6 @@ void llama_kv_cache_unified::commit() {
 }
 
 bool llama_kv_cache_unified::update(llama_context & lctx) {
-    bool need_reserve = false;
-
     auto * sched = lctx.get_sched();
 
     if (has_shift) {
@@ -396,8 +394,6 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
             res->set_inputs(nullptr);
 
             lctx.graph_compute(gf, false);
-
-            need_reserve = true;
         }
 
         {
@@ -411,27 +407,36 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
 
     if (do_defrag) {
         LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
+        const uint32_t n_max_nodes = lctx.graph_max_nodes();
+        const uint32_t max_moves = (n_max_nodes - 2*model.hparams.n_layer)/(6*model.hparams.n_layer);
+        if (!defrag_prepare(n_max_nodes)) {
+            LLAMA_LOG_ERROR("%s: failed to prepare defragmentation\n", __func__);
+            return false;
+        }
+
+        for (std::size_t i = 0; i < defrag_info.moves.size(); i += max_moves) {
+            std::vector<struct llama_kv_defrag_move> chunk;
+            auto end = std::min(i + max_moves, defrag_info.moves.size());
+            chunk.assign(defrag_info.moves.begin() + i, defrag_info.moves.begin() + end);
 
-        if (defrag_prepare(lctx.graph_max_nodes())) {
             ggml_backend_sched_reset(sched);
 
             auto * gf = lctx.graph_init();
 
-            auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
+            auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf, chunk);
 
             ggml_backend_sched_alloc_graph(sched, gf);
 
             res->set_inputs(nullptr);
 
             lctx.graph_compute(gf, false);
-
-            need_reserve = true;
         }
 
         do_defrag = false;
     }
129
 
130
131
132
133
-    return need_reserve;
+    // we never need to reserve a worst case graph
+    return false;
 }
134
 
135
136
137
138
139
140
141
142
143
144
145
 void llama_kv_cache_unified::defrag_sched(float thold) {
@@ -715,11 +720,10 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
 llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
         const llama_cparams & cparams,
                ggml_context * ctx,
-                ggml_cgraph * gf) const {
+                ggml_cgraph * gf,
+                const std::vector<struct llama_kv_defrag_move> & moves) const {
     auto res = std::make_unique<llm_graph_result>();
 
-    const auto & ids = defrag_info.ids;
146
147
148
149
-
 #if 0
     // CPU defrag
     //
150
@@ -791,32 +795,20 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
         ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
     }
 #else
-    for (uint32_t i = 0; i < ids.size(); ++i) {
-        const uint32_t id = ids[i];
-
-        if (i == id || id == ids.size()) {
-            continue;
-        }
-
-        uint32_t nm = 1;
-
-        while (i + nm < ids.size() && ids[i + nm] == id + nm) {
-            nm++;
-        }
-
+    for (const auto & move : moves) {
         for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
             const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
             const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
172
             ggml_tensor * view_k_src = ggml_view_2d(ctx, k_l[il],
173
174
-                    n_embd_k_gqa, nm,
+                    n_embd_k_gqa, move.len,
175
176
177
                     ggml_row_size(k_l[il]->type, n_embd_k_gqa),
-                    ggml_row_size(k_l[il]->type, n_embd_k_gqa*i));
+                    ggml_row_size(k_l[il]->type, n_embd_k_gqa*move.src));
178
 
179
             ggml_tensor * view_k_dst = ggml_view_2d(ctx, k_l[il],
180
181
-                    n_embd_k_gqa, nm,
+                    n_embd_k_gqa, move.len,
182
183
184
                     ggml_row_size(k_l[il]->type, n_embd_k_gqa),
-                    ggml_row_size(k_l[il]->type, n_embd_k_gqa*id));
+                    ggml_row_size(k_l[il]->type, n_embd_k_gqa*move.dst));
185
186
187
 
             ggml_tensor * view_v_src;
             ggml_tensor * view_v_dst;
188
@@ -824,31 +816,29 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
189
190
             if (cparams.flash_attn) {
                 // NOTE: the V cache is not transposed when using flash attention
191
                 view_v_src = ggml_view_2d(ctx, v_l[il],
192
193
-                        n_embd_v_gqa, nm,
+                        n_embd_v_gqa, move.len,
194
195
196
                         ggml_row_size(v_l[il]->type, n_embd_v_gqa),
-                        ggml_row_size(v_l[il]->type, n_embd_v_gqa*i));
+                        ggml_row_size(v_l[il]->type, n_embd_v_gqa*move.dst));
197
 
198
                 view_v_dst = ggml_view_2d(ctx, v_l[il],
199
-                        n_embd_v_gqa, nm,
200
201
202
203
+                        move.len, n_embd_v_gqa,
                         ggml_row_size(v_l[il]->type, n_embd_v_gqa),
-                        ggml_row_size(v_l[il]->type, n_embd_v_gqa*id));
+                        ggml_row_size(v_l[il]->type, move.src));
204
             } else {
205
                 view_v_src = ggml_view_2d(ctx, v_l[il],
206
207
-                        nm, n_embd_v_gqa,
+                        move.len, n_embd_v_gqa,
208
209
210
                         ggml_row_size(v_l[il]->type, size),
-                        ggml_row_size(v_l[il]->type, i));
+                        ggml_row_size(v_l[il]->type, move.src));
211
 
212
                 view_v_dst = ggml_view_2d(ctx, v_l[il],
213
214
-                        nm, n_embd_v_gqa,
+                        move.len, n_embd_v_gqa,
215
216
217
                         ggml_row_size(v_l[il]->type, size),
-                        ggml_row_size(v_l[il]->type, id));
+                        ggml_row_size(v_l[il]->type, move.dst));
218
219
             }
 
220
221
             ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst));
             ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst));
222
223
224
225
226
         }
-
-        i += nm - 1;
     }
 
227
228
     //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
@@ -865,17 +855,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
 
     assert(n_used <= n_kv);
 
-    //const int64_t t_start = ggml_time_us();
-
-    // number of cells moved
-    uint32_t n_moves = 0;
-
-    // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
-    //   - source view, destination view, copy operation
-    //   - x2 for keys and values
-    //const uint32_t max_moves = max_nodes()/(6*n_layer);
-    // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
-    const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
+    defrag_info.moves.clear();
 
     // determine which KV cells to move where
     //
247
@@ -883,10 +863,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
248
249
250
251
252
253
254
255
256
257
258
     //
     //  if ids[i] == i || ids[i] == n_kv, then cell i is not moved
     //
-    auto & ids = defrag_info.ids;
-
-    ids.clear();
-    ids.resize(n_kv, n_kv);
+    std::vector<uint32_t> ids(n_kv, n_kv);
 
     for (uint32_t i0 = 0; i0 < n_used; ++i0) {
         const auto & cell0 = cells[i0];
259
@@ -935,19 +912,11 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
         // are we moving a continuous block of memory?
         bool cont = false;
 
-        // should we stop searching for the next move?
-        bool stop = false;
-
         // go back and move the nf cells to the hole
         for (; i1 < n_kv; ++i1) {
             auto & cell1 = cells[i1];
 
             if (cell1.is_empty() || ids[i1] != n_kv) {
-                if (n_moves == max_moves) {
-                    stop = true;
-                    break;
-                }
-
                 cont = false;
                 continue;
             }
279
@@ -963,8 +932,10 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
280
281
282
283
284
285
286
287
288
289
290
             head = n_used;
 
             if (!cont) {
-                n_moves++;
+                defrag_info.moves.push_back({i1, i0 + nf, 1});
                 cont = true;
+            } else {
+                defrag_info.moves.back().len++;
             }
 
             nf++;
291
@@ -974,22 +945,16 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
             }
         }
 
-        if (stop || n_moves == max_moves) {
-            break;
-        }
-
         //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
 
         i0 += nh - 1;
     }
 
-    if (n_moves == 0) {
+    if (defrag_info.moves.size() == 0) {
         return false;
     }
 
309
-    LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
310
-
311
-    LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
312
313
314
315
316
+    // LLAMA_LOG_DEBUG("(tmp log) KV defrag cell moves: %u\n", n_moves);
 
     return true;
 }
diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
317
index bf3b4b6a..928b9712 100644
318
319
--- a/src/llama-kv-cache.h
+++ b/src/llama-kv-cache.h
320
321
@@ -82,6 +82,13 @@ struct llama_kv_cache_guard {
 private:
322
323
     llama_kv_cache * kv;
 };
324
+ 
325
326
327
328
329
330
331
+// block of KV slots to move when defragging
+struct llama_kv_defrag_move {
+    uint32_t src;
+    uint32_t dst;
+    uint32_t len;
+};
 
332
333
334
335
336
 //
 // llama_kv_cache_unified
@@ -207,7 +214,7 @@ private:
 
     // defrag
337
338
339
340
341
342
     struct {
-        std::vector<uint32_t> ids;
+        std::vector<llama_kv_defrag_move> moves;
     } defrag_info;
 
     // return true if cells have been moved
343
344
345
346
347
348
349
350
351
352
@@ -249,7 +256,8 @@ private:
     llm_graph_result_ptr build_graph_defrag(
             const llama_cparams & cparams,
                    ggml_context * ctx,
-                    ggml_cgraph * gf) const;
+                    ggml_cgraph * gf,
+                    const std::vector<llama_kv_defrag_move> & moves) const;
 
     void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
     void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;