0008-ensure-KV-cache-is-fully-defragmented.patch
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: jmorganca <jmorganca@gmail.com>
Date: Tue, 15 Apr 2025 14:27:40 -0400
Subject: [PATCH] ensure KV cache is fully defragmented

Sometimes the KV cache requires defragmentation even without
triggering the threshold heuristic. In this case, decoding will not
be able to find a KV cache slot. This is particularly difficult for
the caller to handle if it happens in between ubatches. To avoid
this, we should immediately trigger a defrag.
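
In rough terms, the decode path now retries once after forcing a
defrag pass (a sketch of the control flow in the llama-context.cpp
hunk below, not additional code to apply):

    if (!kv_self->find_slot(ubatch)) {
        kv_self->defrag_sched(-1.0f); // negative threshold forces a defrag
        kv_self->update(*this);       // runs the defrag graph(s) now
        if (!kv_self->find_slot(ubatch)) {
            return 1;                 // still no room for this ubatch
        }
    }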

In addition, a heavily fragmented cache can require more than
max_moves to defragment. Currently, we stop when we hit the limit,
but this can leave the cache without adequate space even after
defragmentation has been triggered. Instead, we should process the
moves in multiple batches until defragmentation is complete.
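
Schematically, the defrag pass in llama_kv_cache_unified::update()
now walks the planned moves in graph-sized chunks (sketch only, with
`moves` and `n_layer` standing in for defrag_info.moves and
model.hparams.n_layer; the real change is in the llama-kv-cache.cpp
hunk below):

    const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
    for (std::size_t i = 0; i < moves.size(); i += max_moves) {
        auto end = std::min(i + max_moves, moves.size());
        std::vector<llama_kv_defrag_move> chunk(moves.begin() + i,
                                                moves.begin() + end);
        // build and compute one defrag graph for this chunk of moves
    }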
---
 src/llama-context.cpp  |  18 ++++---
 src/llama-context.h    |   1 +
 src/llama-kv-cache.cpp | 107 ++++++++++++++---------------------------
 src/llama-kv-cache.h   |  12 ++++-
 4 files changed, 59 insertions(+), 79 deletions(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index dca22d8b..1f3a3956 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -947,9 +947,12 @@ int llama_context::decode(llama_batch & inp_batch) {
 
         // find KV slot
         if (!kv_self->find_slot(ubatch)) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
-            return 1;
+            kv_self->defrag_sched(-1.0f);
+            kv_self->update(*this);
+            if (!kv_self->find_slot(ubatch)) {
+                LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
+                return 1;
+            }
         }
 
         ggml_backend_sched_reset(sched.get());
@@ -1965,9 +1968,12 @@ void llama_context::opt_epoch_iter(
 
             // TODO: not sure if this is needed
             if (!kv_self->find_slot(ubatch)) {
-                LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
-                GGML_ABORT("TODO: handle this error");
+                kv_self->defrag_sched(-1.0f);
+                kv_self->update(*this);
+                if (!kv_self->find_slot(ubatch)) {
+                    LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
+                    GGML_ABORT("TODO: handle this error");
+                }
             }
 
             auto * gf = graph_init();
diff --git a/src/llama-context.h b/src/llama-context.h
index c0ceacb1..0264e937 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -5,6 +5,7 @@
 #include "llama-cparams.h"
 #include "llama-graph.h"
 #include "llama-adapter.h"
+#include "llama-kv-cache.h"
 
 #include "ggml-cpp.h"
 #include "ggml-opt.h"
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index 3dcad65b..60e67b03 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -364,8 +364,6 @@ void llama_kv_cache_unified::commit() {
 }
 
 bool llama_kv_cache_unified::update(llama_context & lctx) {
-    bool need_reserve = false;
-
     auto * sched = lctx.get_sched();
 
     if (has_shift) {
@@ -388,8 +386,6 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
             res->set_inputs(nullptr);
 
             lctx.graph_compute(gf, false);
-
-            need_reserve = true;
         }
 
         {
@@ -403,27 +399,36 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
 
     if (do_defrag) {
         LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
+        const uint32_t n_max_nodes = lctx.graph_max_nodes();
+        const uint32_t max_moves = (n_max_nodes - 2*model.hparams.n_layer)/(6*model.hparams.n_layer);
+        if (!defrag_prepare(n_max_nodes)) {
+            LLAMA_LOG_ERROR("%s: failed to prepare defragmentation\n", __func__);
+            return false;
+        }
+
+        for (std::size_t i = 0; i < defrag_info.moves.size(); i += max_moves) {
+            std::vector<struct llama_kv_defrag_move> chunk;
+            auto end = std::min(i + max_moves, defrag_info.moves.size());
+            chunk.assign(defrag_info.moves.begin() + i, defrag_info.moves.begin() + end);
 
-        if (defrag_prepare(lctx.graph_max_nodes())) {
             ggml_backend_sched_reset(sched);
 
             auto * gf = lctx.graph_init();
 
-            auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
+            auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf, chunk);
 
             ggml_backend_sched_alloc_graph(sched, gf);
 
             res->set_inputs(nullptr);
 
             lctx.graph_compute(gf, false);
-
-            need_reserve = true;
         }
 
         do_defrag = false;
     }
 
-    return need_reserve;
+    // we never need to reserve a worst case graph
+    return false;
 }
 
 void llama_kv_cache_unified::defrag_sched(float thold) {
@@ -707,11 +712,10 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
 llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
         const llama_cparams & cparams,
                ggml_context * ctx,
-                ggml_cgraph * gf) const {
+                ggml_cgraph * gf,
+                const std::vector<struct llama_kv_defrag_move> & moves) const {
     auto res = std::make_unique<llm_graph_result>();
 
-    const auto & ids = defrag_info.ids;
-
 #if 0
     // CPU defrag
     //
@@ -783,32 +787,20 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
         ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
     }
 #else
-    for (uint32_t i = 0; i < ids.size(); ++i) {
-        const uint32_t id = ids[i];
-
-        if (i == id || id == ids.size()) {
-            continue;
-        }
-
-        uint32_t nm = 1;
-
-        while (i + nm < ids.size() && ids[i + nm] == id + nm) {
-            nm++;
-        }
-
+    for (const auto & move : moves) {
         for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
             const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
             const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
             ggml_tensor * view_k_src = ggml_view_2d(ctx, k_l[il],
-                    n_embd_k_gqa, nm,
+                    n_embd_k_gqa, move.len,
                     ggml_row_size(k_l[il]->type, n_embd_k_gqa),
-                    ggml_row_size(k_l[il]->type, n_embd_k_gqa*i));
+                    ggml_row_size(k_l[il]->type, n_embd_k_gqa*move.src));
 
             ggml_tensor * view_k_dst = ggml_view_2d(ctx, k_l[il],
-                    n_embd_k_gqa, nm,
+                    n_embd_k_gqa, move.len,
                     ggml_row_size(k_l[il]->type, n_embd_k_gqa),
-                    ggml_row_size(k_l[il]->type, n_embd_k_gqa*id));
+                    ggml_row_size(k_l[il]->type, n_embd_k_gqa*move.dst));
 
             ggml_tensor * view_v_src;
             ggml_tensor * view_v_dst;
@@ -816,31 +808,29 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
             if (cparams.flash_attn) {
                 // NOTE: the V cache is not transposed when using flash attention
                 view_v_src = ggml_view_2d(ctx, v_l[il],
-                        n_embd_v_gqa, nm,
+                        n_embd_v_gqa, move.len,
                         ggml_row_size(v_l[il]->type, n_embd_v_gqa),
-                        ggml_row_size(v_l[il]->type, n_embd_v_gqa*i));
+                        ggml_row_size(v_l[il]->type, n_embd_v_gqa*move.src));
 
                 view_v_dst = ggml_view_2d(ctx, v_l[il],
-                        n_embd_v_gqa, nm,
+                        n_embd_v_gqa, move.len,
                         ggml_row_size(v_l[il]->type, n_embd_v_gqa),
-                        ggml_row_size(v_l[il]->type, n_embd_v_gqa*id));
+                        ggml_row_size(v_l[il]->type, n_embd_v_gqa*move.dst));
             } else {
                 view_v_src = ggml_view_2d(ctx, v_l[il],
-                        nm, n_embd_v_gqa,
+                        move.len, n_embd_v_gqa,
                         ggml_row_size(v_l[il]->type, size),
-                        ggml_row_size(v_l[il]->type, i));
+                        ggml_row_size(v_l[il]->type, move.src));
 
                 view_v_dst = ggml_view_2d(ctx, v_l[il],
-                        nm, n_embd_v_gqa,
+                        move.len, n_embd_v_gqa,
                         ggml_row_size(v_l[il]->type, size),
-                        ggml_row_size(v_l[il]->type, id));
+                        ggml_row_size(v_l[il]->type, move.dst));
             }
 
             ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst));
             ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst));
         }
-
-        i += nm - 1;
     }
 
     //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
@@ -857,17 +847,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
 
     assert(n_used <= n_kv);
 
-    //const int64_t t_start = ggml_time_us();
-
-    // number of cells moved
-    uint32_t n_moves = 0;
-
-    // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
-    //   - source view, destination view, copy operation
-    //   - x2 for keys and values
-    //const uint32_t max_moves = max_nodes()/(6*n_layer);
-    // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
-    const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
+    defrag_info.moves.clear();
 
     // determine which KV cells to move where
     //
@@ -875,10 +855,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
     //
     //  if ids[i] == i || ids[i] == n_kv, then cell i is not moved
     //
-    auto & ids = defrag_info.ids;
-
-    ids.clear();
-    ids.resize(n_kv, n_kv);
+    std::vector<uint32_t> ids(n_kv, n_kv);
 
     for (uint32_t i0 = 0; i0 < n_used; ++i0) {
         const auto & cell0 = cells[i0];
@@ -927,19 +904,11 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
         // are we moving a continuous block of memory?
         bool cont = false;
 
-        // should we stop searching for the next move?
-        bool stop = false;
-
         // go back and move the nf cells to the hole
         for (; i1 < n_kv; ++i1) {
             auto & cell1 = cells[i1];
 
             if (cell1.is_empty() || ids[i1] != n_kv) {
-                if (n_moves == max_moves) {
-                    stop = true;
-                    break;
-                }
-
                 cont = false;
                 continue;
             }
@@ -955,8 +924,10 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
             head = n_used;
 
             if (!cont) {
-                n_moves++;
+                defrag_info.moves.push_back({i1, i0 + nf, 1});
                 cont = true;
+            } else {
+                defrag_info.moves.back().len++;
             }
 
             nf++;
@@ -966,22 +937,16 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
             }
         }
 
-        if (stop || n_moves == max_moves) {
-            break;
-        }
-
         //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
 
         i0 += nh - 1;
     }
 
-    if (n_moves == 0) {
+    if (defrag_info.moves.size() == 0) {
         return false;
     }
 
-    LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
-
-    LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
+    // LLAMA_LOG_DEBUG("(tmp log) KV defrag cell moves: %u\n", n_moves);
 
     return true;
 }
diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
index bf3b4b6a..928b9712 100644
--- a/src/llama-kv-cache.h
+++ b/src/llama-kv-cache.h
@@ -82,6 +82,13 @@ struct llama_kv_cache_guard {
 private:
     llama_kv_cache * kv;
 };
+
+// block of KV slots to move when defragging
+struct llama_kv_defrag_move {
+    uint32_t src;
+    uint32_t dst;
+    uint32_t len;
+};
 
 //
 // llama_kv_cache_unified
@@ -207,7 +214,7 @@ private:
 
     // defrag
     struct {
-        std::vector<uint32_t> ids;
+        std::vector<llama_kv_defrag_move> moves;
     } defrag_info;
 
     // return true if cells have been moved
@@ -249,7 +256,8 @@ private:
     llm_graph_result_ptr build_graph_defrag(
             const llama_cparams & cparams,
                    ggml_context * ctx,
-                    ggml_cgraph * gf) const;
+                    ggml_cgraph * gf,
+                    const std::vector<llama_kv_defrag_move> & moves) const;
 
     void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
     void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;