From 235b6d876a74cb09abe26985fa89ebe5bfc9f562 Mon Sep 17 00:00:00 2001
From: Gabe Goodhart <ghart@us.ibm.com>
Date: Thu, 19 Sep 2024 17:06:17 -0600
Subject: [PATCH] embeddings
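
Tie logits output to causal attention rather than to the embeddings flag:
llama_output_reserve() now sizes the logits buffer from cparams.causal_attn
instead of !cparams.embeddings, llama_decode_internal() looks up the
result_embd_pooled graph node without asserting that it exists, and logits
extraction is skipped only when attention is non-causal. This lets a causal
model return logits and embeddings from the same decode call.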

---
 src/llama.cpp | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index 1a8e0c51..e55ec3f8 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -16516,7 +16516,7 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
     const auto n_embd  = hparams.n_embd;
 
     // TODO: use a per-batch flag for logits presence instead
-    const bool has_logits = !cparams.embeddings;
+    const bool has_logits =  cparams.causal_attn;
     const bool has_embd   =  cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
 
     const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
@@ -16794,20 +16794,23 @@ static int llama_decode_internal(
             // no output
             res  = nullptr;
             embd = nullptr;
-        } else if (cparams.embeddings) {
-            res  = nullptr; // do not extract logits for embedding case
-            embd = nullptr;
+        }
+
+        if (cparams.embeddings) {
             for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
+                embd = ggml_graph_node(gf, i);
                 if (strcmp(ggml_graph_node(gf, i)->name, "result_embd_pooled") == 0) {
-                    embd = ggml_graph_node(gf, i);
                     break;
                 }
             }
-            GGML_ASSERT(embd != nullptr && "missing embeddings tensor");
         } else {
             embd = nullptr; // do not extract embeddings when not needed
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
         }
+
+        if (!cparams.causal_attn) {
+            res = nullptr; // do not extract logits when not needed
+        }
         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
         ggml_backend_sched_alloc_graph(lctx.sched, gf);
-- 
2.39.3 (Apple Git-146)
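
For reference, a minimal standalone sketch of the buffer-sizing change in the
first hunk: llama_output_reserve() now reserves the logits buffer when
attention is causal, instead of only when embeddings are off. The n_vocab and
n_outputs_max values below are made up for illustration; only the
has_logits/has_embd expressions mirror the patched code.

// Sketch only: mirrors the has_logits/has_embd computation after this patch.
#include <cstddef>
#include <cstdio>

int main() {
    // Illustrative values, not taken from any real model.
    const size_t n_vocab       = 32000;
    const size_t n_outputs_max = 8;

    const bool embeddings   = true; // cparams.embeddings stand-in
    const bool causal_attn  = true; // cparams.causal_attn stand-in
    const bool pooling_none = true; // pooling_type == LLAMA_POOLING_TYPE_NONE

    // before: const bool has_logits = !embeddings;  -> 0 floats when embedding
    // after:  const bool has_logits =  causal_attn; -> logits still reserved
    const bool has_logits = causal_attn;
    const bool has_embd   = embeddings && pooling_none;

    const size_t logits_size = has_logits ? n_vocab * n_outputs_max : 0;
    printf("logits floats reserved: %zu, embd reserved: %s\n",
           logits_size, has_embd ? "yes" : "no");
    return 0;
}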
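
And a self-contained sketch of the output-selection logic the second hunk
installs in llama_decode_internal(). Node and pick_outputs are illustrative
stand-ins, not llama.cpp API. Note that, as in the patch, embd ends up
pointing at the last node visited rather than nullptr if no node named
result_embd_pooled exists, since the GGML_ASSERT was removed.

// Sketch only: Node and pick_outputs are stand-ins, not llama.cpp types.
#include <cstdio>
#include <cstring>
#include <vector>

struct Node    { const char * name; };
struct Outputs { Node * res; Node * embd; };

static Outputs pick_outputs(std::vector<Node> & graph, Node * result_output,
                            bool embeddings, bool causal_attn) {
    Outputs out = { result_output, nullptr };
    if (embeddings) {
        // walk the graph backwards looking for the pooled-embeddings node;
        // if the name is never found, embd is left on the last node visited
        for (int i = (int) graph.size() - 1; i >= 0; --i) {
            out.embd = &graph[i];
            if (strcmp(graph[i].name, "result_embd_pooled") == 0) {
                break;
            }
        }
    } else {
        out.embd = nullptr; // do not extract embeddings when not needed
    }
    if (!causal_attn) {
        out.res = nullptr; // do not extract logits when not needed
    }
    return out;
}

int main() {
    std::vector<Node> graph = {
        {"inp_tokens"}, {"result_norm"}, {"result_embd_pooled"}, {"result_output"},
    };
    Outputs o = pick_outputs(graph, &graph[3],
                             /*embeddings=*/true, /*causal_attn=*/true);
    printf("embd: %s, logits: %s\n",
           o.embd ? o.embd->name : "(none)",
           o.res  ? o.res->name  : "(none)");
    return 0;
}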