From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <mxyng@pm.me>
Date: Mon, 16 Sep 2024 15:53:14 -0700
Subject: [PATCH] embeddings

---
 src/llama-context.cpp | 2 +-
 src/llama.cpp         | 6 ++++--
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 671d2a81..47e79ed4 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -479,7 +479,7 @@ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
     const auto n_embd  = hparams.n_embd;
 
     // TODO: use a per-batch flag for logits presence instead
-    const bool has_logits = !cparams.embeddings;
+    const bool has_logits =  cparams.causal_attn;
     const bool has_embd   =  cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
 
     const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
diff --git a/src/llama.cpp b/src/llama.cpp
index 607f2786..ac85bfed 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -8652,7 +8652,6 @@ static int llama_decode_impl(
             res  = nullptr;
             embd = nullptr;
         } else if (cparams.embeddings) {
-            res  = nullptr; // do not extract logits for embedding case
             embd = nullptr;
             for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
                 if (strcmp(ggml_graph_node(gf, i)->name, "result_embd_pooled") == 0) {
@@ -8660,12 +8659,15 @@ static int llama_decode_impl(
                     break;
                 }
             }
-            GGML_ASSERT(embd != nullptr && "missing embeddings tensor");
         } else {
             embd = nullptr; // do not extract embeddings when not needed
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
         }
 
+        if (!cparams.causal_attn) {
+            res = nullptr; // do not extract logits when not needed
+        }
+
         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
         ggml_backend_sched_alloc_graph(lctx.sched.get(), gf);