0011-llama-Ensure-KV-cache-is-fully-defragmented.patch
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jesse Gross <jesse@ollama.com>
Date: Fri, 13 Dec 2024 16:11:59 -0800
Subject: [PATCH] llama: Ensure KV cache is fully defragmented.

Sometimes the KV cache requires defragmentation even without
triggering the threshold heuristic. In this case, decoding will not
be able to find a KV cache slot. This is particularly difficult for
the caller to handle if it happens between ubatches. To avoid this,
we should immediately trigger a defrag.
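
A minimal sketch of the retry pattern, using hypothetical stand-in
types rather than the real llama.cpp API (the actual change is in
llama_prepare_ubatch below):

    #include <optional>

    // hypothetical stand-ins, not the real llama.cpp types
    struct kv_cache { int free_cells = 0; };
    struct ubatch   { int n_tokens   = 4; };

    static std::optional<int> find_slot(kv_cache & kv, const ubatch & ub) {
        if (kv.free_cells >= ub.n_tokens) {
            return 0; // illustrative slot index
        }
        return std::nullopt;
    }

    static void defrag(kv_cache & kv) {
        kv.free_cells += 8; // pretend compaction reclaims some cells
    }

    // decode path: if no slot is found, defragment once and retry before failing
    static int prepare_ubatch(kv_cache & kv, const ubatch & ub) {
        auto slot = find_slot(kv, ub);
        if (!slot) {
            defrag(kv);
            slot = find_slot(kv, ub);
        }
        return slot ? 0 : 1; // 1 tells the caller that no slot is available
    }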

In addition, a heavily fragmented cache can require more moves than
max_moves to fully defragment. Currently, we stop when we hit the
limit, but this can leave a cache that still lacks adequate space
even after defragmentation is triggered. Instead, we should do
multiple batches of processing until everything is complete.
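
The batching change amounts to the loop shape sketched below with
placeholder names (the real code builds and computes a ggml defrag
graph per chunk; see llama_kv_cache_defrag_impl):

    #include <algorithm>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // placeholder for the per-block move descriptor this patch introduces
    struct kv_defrag_move { unsigned src, dst, len; };

    // apply all pending moves in chunks of at most max_moves, so a heavily
    // fragmented cache is fully compacted instead of stopping at the limit
    static void apply_moves(const std::vector<kv_defrag_move> & moves, std::size_t max_moves) {
        for (std::size_t i = 0; i < moves.size(); i += max_moves) {
            const std::size_t end = std::min(i + max_moves, moves.size());
            const std::vector<kv_defrag_move> chunk(moves.begin() + i, moves.begin() + end);
            // in the real code: build a defrag graph for this chunk and compute it
            std::printf("processing %zu moves in this pass\n", chunk.size());
        }
    }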
---
 src/llama.cpp | 99 ++++++++++++++++++++++++---------------------------
 1 file changed, 46 insertions(+), 53 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index 8f7902df..01854fce 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -1054,6 +1054,13 @@ static struct ggml_tensor * llm_build_rwkv6_channel_mix(
     return ggml_mul(ctx, r, llm_build_lora_mm(lctx, ctx, layer->channel_mix_value, k));
 }
 
+// block of KV slots to move when defragging
+struct llama_kv_defrag_move {
+    uint32_t src;
+    uint32_t dst;
+    uint32_t len;
+};
+
 struct llm_build_context {
     const llama_model    & model;
           llama_context  & lctx;
@@ -1230,35 +1237,23 @@ struct llm_build_context {
         return gf;
     }
 
-    struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
+    struct ggml_cgraph * build_defrag(const std::vector<struct llama_kv_defrag_move> & moves) {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
-        for (uint32_t i = 0; i < ids.size(); ++i) {
-            const uint32_t id = ids[i];
-
-            if (i == id || id == ids.size()) {
-                continue;
-            }
-
-            uint32_t nm = 1;
-
-            while (i + nm < ids.size() && ids[i + nm] == id + nm) {
-                nm++;
-            }
-
+        for (const auto & move : moves) {
             for (int il = 0; il < n_layer; ++il) {
                 const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
                 const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
                 ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il],
-                        n_embd_k_gqa, nm,
+                        n_embd_k_gqa, move.len,
                         ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
-                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*i));
+                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*move.src));
 
                 ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self.k_l[il],
-                        n_embd_k_gqa, nm,
+                        n_embd_k_gqa, move.len,
                         ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
-                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id));
+                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*move.dst));
 
                 ggml_tensor * view_v_src;
                 ggml_tensor * view_v_dst;
@@ -1266,31 +1261,29 @@ struct llm_build_context {
                 if (flash_attn) {
                     // NOTE: the V cache is not transposed when using flash attention
                     view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
-                            n_embd_v_gqa, nm,
+                            n_embd_v_gqa, move.len,
                             ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
-                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i));
+                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*move.src));
 
                     view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
-                            n_embd_v_gqa, nm,
+                            n_embd_v_gqa, move.len,
                             ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
-                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id));
+                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*move.dst));
                 } else {
                     view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
-                            nm, n_embd_v_gqa,
+                            move.len, n_embd_v_gqa,
                             ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
-                            ggml_row_size(kv_self.v_l[il]->type, i));
+                            ggml_row_size(kv_self.v_l[il]->type, move.src));
 
                     view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
-                            nm, n_embd_v_gqa,
+                            move.len, n_embd_v_gqa,
                             ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
-                            ggml_row_size(kv_self.v_l[il]->type, id));
+                            ggml_row_size(kv_self.v_l[il]->type, move.dst));
                 }
 
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
             }
-
-            i += nm - 1;
         }
 
         //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
@@ -8508,7 +8501,7 @@ struct llm_build_context {
     }
 };
 
-static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
+static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<struct llama_kv_defrag_move> & moves) {
     llama_ubatch dummy = {};
     dummy.equal_seqs = true;
 
@@ -8518,7 +8511,7 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const
 
     llm.init();
 
-    struct ggml_cgraph * result = llm.build_defrag(ids);
+    struct ggml_cgraph * result = llm.build_defrag(moves);
 
     llm.free();
 
@@ -8956,7 +8949,12 @@ static int llama_prepare_ubatch(
             kv_self.head = 0;
         }
 
-        const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
+        auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
+        if (!slot) {
+            llama_kv_cache_defrag(kv_self);
+            llama_kv_cache_update(&lctx);
+            slot = llama_kv_cache_find_slot(kv_self, ubatch);
+        }
         if (!slot) {
             return 1;
         }
@@ -9431,8 +9429,8 @@ static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
 
     //const int64_t t_start = ggml_time_us();
 
-    // number of cells moved
-    uint32_t n_moves = 0;
+    // groups of cells moved
+    std::vector<struct llama_kv_defrag_move> moves;
 
     // each move requires 6*n_layer tensors (see build_defrag)
     //   - source view, destination view, copy operation
@@ -9496,19 +9494,11 @@ static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
         // are we moving a continuous block of memory?
         bool cont = false;
 
-        // should we stop searching for the next move?
-        bool stop = false;
-
         // go back and move the nf cells to the hole
         for (; i1 < n_kv; ++i1) {
             auto & cell1 = kv_self.cells[i1];
 
             if (cell1.is_empty() || ids[i1] != n_kv) {
-                if (n_moves == max_moves) {
-                    stop = true;
-                    break;
-                }
-
                 cont = false;
                 continue;
             }
@@ -9524,8 +9514,10 @@ static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
             kv_self.head = n_used;
 
             if (!cont) {
-                n_moves++;
+                moves.push_back({i1, i0 + nf, 1});
                 cont = true;
+            } else {
+                moves.back().len++;
             }
 
             nf++;
@@ -9535,22 +9527,16 @@ static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
             }
         }
 
-        if (stop || n_moves == max_moves) {
-            break;
-        }
-
         //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
 
         i0 += nh - 1;
     }
 
-    if (n_moves == 0) {
+    if (moves.size() == 0) {
         return;
     }
 
-    //LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
-
-    //LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*n_moves*n_layer);
+    //LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n",  moves.size());
 
 #if 0
     // CPU defrag
@@ -9625,11 +9611,18 @@ static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
 #else
     // ggml_graph defrag
 
-    ggml_backend_sched_reset(lctx.sched.get());
+    for (std::size_t i = 0; i < moves.size(); i += max_moves) {
+        std::vector<struct llama_kv_defrag_move> chunk;
+        auto end = std::min(i + max_moves, moves.size());
+        chunk.assign(moves.begin() + i, moves.begin() + end);
 
-    ggml_cgraph * gf = llama_build_graph_defrag(lctx, ids);
+        ggml_backend_sched_reset(lctx.sched.get());
+
+        //LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*chunk.size()*n_layer);
+        ggml_cgraph * gf = llama_build_graph_defrag(lctx, chunk);
 
-    llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
+        llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
+    }
 #endif
 
     //const int64_t t_end = ggml_time_us();