0010-ensure-KV-cache-is-fully-defragmented.patch
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: jmorganca <jmorganca@gmail.com>
Date: Tue, 15 Apr 2025 14:27:40 -0400
Subject: [PATCH] ensure KV cache is fully defragmented

Sometimes the KV cache requires defragmentation even without
triggering the threshold heuristic. In this case, decoding will
not be able to find a KV cache slot. This is particularly
difficult for the caller to handle if it happens in between
ubatches. To avoid this, we should immediately trigger a defrag.
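
For illustration only, here is a toy model of the failure mode
(hypothetical helper names, not the llama.cpp API): a slot request
needs contiguous free cells, so it can fail even when enough cells
are free in total, and compacting the cache immediately makes the
same request succeed:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // toy KV cache: true = occupied cell, false = free cell
    static int find_slot(const std::vector<bool> & cells, int n) {
        int run = 0;
        for (int i = 0; i < (int) cells.size(); ++i) {
            run = cells[i] ? 0 : run + 1;
            if (run == n) return i - n + 1;  // start of the free run
        }
        return -1;  // no contiguous run of n free cells
    }

    static void defrag(std::vector<bool> & cells) {
        // compact occupied cells to the front (stand-in for the move graph)
        const auto used = std::count(cells.begin(), cells.end(), true);
        std::fill(cells.begin(), cells.begin() + used, true);
        std::fill(cells.begin() + used, cells.end(), false);
    }

    int main() {
        std::vector<bool> cells(8, false);
        for (int i : {0, 2, 4, 6}) cells[i] = true;  // 4 free cells, none adjacent

        int slot = find_slot(cells, 2);  // fails despite enough free space
        if (slot < 0) {
            defrag(cells);               // defrag immediately...
            slot = find_slot(cells, 2);  // ...and retry instead of failing mid-ubatch
        }
        std::printf("slot = %d\n", slot);  // prints: slot = 4
        return 0;
    }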

In addition, a heavily fragmented cache can require more than
max_moves moves to defragment. Currently, we stop when we hit the
limit, but this can leave a cache that still does not have adequate
space even after defragmentation is triggered. Instead, we should
process the moves in multiple batches until everything is complete.
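
As a minimal, self-contained sketch of that batching (simplified
names; the real code lives in llama_kv_cache_unified::update() and
consumes defrag_info.moves):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct Move { uint32_t src, dst, len; };  // mirrors llama_kv_defrag_move

    int main() {
        const uint32_t n_layer     = 32;
        const uint32_t n_max_nodes = 8192;
        // same budget as the patch: each move costs ~6 graph nodes per
        // layer, with a small reserve of 2 nodes per layer
        const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);

        // pretend defrag_prepare() produced this many moves
        std::vector<Move> moves(1000, Move{0, 0, 1});

        for (std::size_t i = 0; i < moves.size(); i += max_moves) {
            const std::size_t end = std::min<std::size_t>(i + max_moves, moves.size());
            std::vector<Move> chunk(moves.begin() + i, moves.begin() + end);
            // real code: build, schedule and compute one defrag graph per chunk
            std::printf("graph %zu: %zu moves\n", i / max_moves, chunk.size());
        }
        return 0;
    }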
---
 src/llama-context.h    |   1 +
 src/llama-kv-cache.cpp | 107 ++++++++++++++---------------------------
 src/llama-kv-cache.h   |  12 ++++-
 3 files changed, 47 insertions(+), 73 deletions(-)

diff --git a/src/llama-context.h b/src/llama-context.h
index c4ab242a..9970dfc6 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -5,6 +5,7 @@
 #include "llama-cparams.h"
 #include "llama-graph.h"
 #include "llama-adapter.h"
+#include "llama-kv-cache.h"
 
 #include "ggml-cpp.h"
 #include "ggml-opt.h"
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index a7b0a7eb..1a50c034 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -372,8 +372,6 @@ void llama_kv_cache_unified::commit() {
 }
 
 bool llama_kv_cache_unified::update(llama_context & lctx) {
-    bool need_reserve = false;
-
     auto * sched = lctx.get_sched();
 
     if (has_shift) {
@@ -396,8 +394,6 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
             res->set_inputs(nullptr);
 
             lctx.graph_compute(gf, false);
-
-            need_reserve = true;
         }
 
         {
@@ -411,27 +407,36 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
 
     if (do_defrag) {
         LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
+        const uint32_t n_max_nodes = lctx.graph_max_nodes();
+        const uint32_t max_moves = (n_max_nodes - 2*model.hparams.n_layer)/(6*model.hparams.n_layer);
+        if (!defrag_prepare(n_max_nodes)) {
+            LLAMA_LOG_ERROR("%s: failed to prepare defragmentation\n", __func__);
+            return false;
+        }
+
+        for (std::size_t i = 0; i < defrag_info.moves.size(); i += max_moves) {
+            std::vector<struct llama_kv_defrag_move> chunk;
+            auto end = std::min(i + max_moves, defrag_info.moves.size());
+            chunk.assign(defrag_info.moves.begin() + i, defrag_info.moves.begin() + end);
 
-        if (defrag_prepare(lctx.graph_max_nodes())) {
             ggml_backend_sched_reset(sched);
 
             auto * gf = lctx.graph_init();
 
-            auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
+            auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf, chunk);
 
             ggml_backend_sched_alloc_graph(sched, gf);
 
             res->set_inputs(nullptr);
 
             lctx.graph_compute(gf, false);
-
-            need_reserve = true;
         }
 
         do_defrag = false;
     }
 
-    return need_reserve;
+    // we never need to reserve a worst case graph
+    return false;
 }
 
 void llama_kv_cache_unified::defrag_sched(float thold) {
@@ -715,11 +720,10 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
 llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
         const llama_cparams & cparams,
                ggml_context * ctx,
-                ggml_cgraph * gf) const {
+                ggml_cgraph * gf,
+                const std::vector<struct llama_kv_defrag_move> & moves) const {
     auto res = std::make_unique<llm_graph_result>();
 
-    const auto & ids = defrag_info.ids;
-
 #if 0
     // CPU defrag
     //
@@ -791,32 +795,20 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
         ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
     }
 #else
-    for (uint32_t i = 0; i < ids.size(); ++i) {
-        const uint32_t id = ids[i];
-
-        if (i == id || id == ids.size()) {
-            continue;
-        }
-
-        uint32_t nm = 1;
-
-        while (i + nm < ids.size() && ids[i + nm] == id + nm) {
-            nm++;
-        }
-
+    for (const auto & move : moves) {
         for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
             const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
             const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
             ggml_tensor * view_k_src = ggml_view_2d(ctx, k_l[il],
-                    n_embd_k_gqa, nm,
+                    n_embd_k_gqa, move.len,
                     ggml_row_size(k_l[il]->type, n_embd_k_gqa),
-                    ggml_row_size(k_l[il]->type, n_embd_k_gqa*i));
+                    ggml_row_size(k_l[il]->type, n_embd_k_gqa*move.src));
 
             ggml_tensor * view_k_dst = ggml_view_2d(ctx, k_l[il],
-                    n_embd_k_gqa, nm,
+                    n_embd_k_gqa, move.len,
                     ggml_row_size(k_l[il]->type, n_embd_k_gqa),
-                    ggml_row_size(k_l[il]->type, n_embd_k_gqa*id));
+                    ggml_row_size(k_l[il]->type, n_embd_k_gqa*move.dst));
 
             ggml_tensor * view_v_src;
             ggml_tensor * view_v_dst;
@@ -824,31 +816,29 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
             if (cparams.flash_attn) {
                 // NOTE: the V cache is not transposed when using flash attention
                 view_v_src = ggml_view_2d(ctx, v_l[il],
-                        n_embd_v_gqa, nm,
+                        n_embd_v_gqa, move.len,
                         ggml_row_size(v_l[il]->type, n_embd_v_gqa),
-                        ggml_row_size(v_l[il]->type, n_embd_v_gqa*i));
+                        ggml_row_size(v_l[il]->type, n_embd_v_gqa*move.src));
 
                 view_v_dst = ggml_view_2d(ctx, v_l[il],
-                        n_embd_v_gqa, nm,
+                        n_embd_v_gqa, move.len,
                         ggml_row_size(v_l[il]->type, n_embd_v_gqa),
-                        ggml_row_size(v_l[il]->type, n_embd_v_gqa*id));
+                        ggml_row_size(v_l[il]->type, n_embd_v_gqa*move.dst));
             } else {
                 view_v_src = ggml_view_2d(ctx, v_l[il],
-                        nm, n_embd_v_gqa,
+                        move.len, n_embd_v_gqa,
                         ggml_row_size(v_l[il]->type, size),
-                        ggml_row_size(v_l[il]->type, i));
+                        ggml_row_size(v_l[il]->type, move.src));
 
                 view_v_dst = ggml_view_2d(ctx, v_l[il],
-                        nm, n_embd_v_gqa,
+                        move.len, n_embd_v_gqa,
                         ggml_row_size(v_l[il]->type, size),
-                        ggml_row_size(v_l[il]->type, id));
+                        ggml_row_size(v_l[il]->type, move.dst));
             }
 
             ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst));
             ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst));
         }
-
-        i += nm - 1;
     }
 
     //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
@@ -865,17 +855,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
 
     assert(n_used <= n_kv);
 
-    //const int64_t t_start = ggml_time_us();
-
-    // number of cells moved
-    uint32_t n_moves = 0;
-
-    // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
-    //   - source view, destination view, copy operation
-    //   - x2 for keys and values
-    //const uint32_t max_moves = max_nodes()/(6*n_layer);
-    // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
-    const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
+    defrag_info.moves.clear();
 
     // determine which KV cells to move where
     //
@@ -883,10 +863,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
     //
     //  if ids[i] == i || ids[i] == n_kv, then cell i is not moved
     //
-    auto & ids = defrag_info.ids;
-
-    ids.clear();
-    ids.resize(n_kv, n_kv);
+    std::vector<uint32_t> ids(n_kv, n_kv);
 
     for (uint32_t i0 = 0; i0 < n_used; ++i0) {
         const auto & cell0 = cells[i0];
@@ -935,19 +912,11 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
         // are we moving a continuous block of memory?
         bool cont = false;
 
-        // should we stop searching for the next move?
-        bool stop = false;
-
         // go back and move the nf cells to the hole
         for (; i1 < n_kv; ++i1) {
             auto & cell1 = cells[i1];
 
             if (cell1.is_empty() || ids[i1] != n_kv) {
-                if (n_moves == max_moves) {
-                    stop = true;
-                    break;
-                }
-
                 cont = false;
                 continue;
             }
@@ -963,8 +932,10 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
             head = n_used;
 
             if (!cont) {
-                n_moves++;
+                defrag_info.moves.push_back({i1, i0 + nf, 1});
                 cont = true;
+            } else {
+                defrag_info.moves.back().len++;
             }
 
             nf++;
@@ -974,22 +945,16 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
             }
         }
 
-        if (stop || n_moves == max_moves) {
-            break;
-        }
-
         //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
 
         i0 += nh - 1;
     }
 
-    if (n_moves == 0) {
+    if (defrag_info.moves.size() == 0) {
         return false;
     }
 
-    LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
-
-    LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
+    // LLAMA_LOG_DEBUG("(tmp log) KV defrag cell moves: %u\n", n_moves);
 
     return true;
 }
diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
index bf3b4b6a..928b9712 100644
--- a/src/llama-kv-cache.h
+++ b/src/llama-kv-cache.h
@@ -82,6 +82,13 @@ struct llama_kv_cache_guard {
 private:
     llama_kv_cache * kv;
 };
+ 
+// block of KV slots to move when defragging
+struct llama_kv_defrag_move {
+    uint32_t src;
+    uint32_t dst;
+    uint32_t len;
+};
 
 //
 // llama_kv_cache_unified
@@ -207,7 +214,7 @@ private:
 
     // defrag
     struct {
-        std::vector<uint32_t> ids;
+        std::vector<llama_kv_defrag_move> moves;
     } defrag_info;
 
     // return true if cells have been moved
@@ -249,7 +256,8 @@ private:
     llm_graph_result_ptr build_graph_defrag(
             const llama_cparams & cparams,
                    ggml_context * ctx,
-                    ggml_cgraph * gf) const;
+                    ggml_cgraph * gf,
+                    const std::vector<llama_kv_defrag_move> & moves) const;
 
     void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
     void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;