0003-embeddings.patch
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: jmorganca <jmorganca@gmail.com>
Date: Tue, 8 Apr 2025 15:28:34 -0700
Subject: [PATCH] embeddings

allow a loaded model in llama.cpp to be used for
both embeddings and causal attention text generation
instead of forcing one or the other
---
 src/llama-context.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 5a2eef9b..9c1fe93f 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1225,7 +1225,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     int64_t n_outputs_all = 0;
 
     // count outputs
-    if (batch.logits && !embd_pooled) {
+    if (batch.logits) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             n_outputs_all += batch.logits[i] != 0;
         }
@@ -1337,7 +1337,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
         //}
 
-        auto * t_logits = cparams.embeddings ? nullptr         : res->get_logits();
+        auto * t_logits = cparams.causal_attn ? res->get_logits() : nullptr;
         auto * t_embd   = cparams.embeddings ? res->get_embd() : nullptr;
 
         if (t_embd && res->get_embd_pooled()) {
@@ -1481,7 +1481,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
     const auto n_embd  = hparams.n_embd;
 
     // TODO: use a per-batch flag for logits presence instead
-    bool has_logits = !cparams.embeddings;
+    bool has_logits =  cparams.causal_attn;
     bool has_embd   =  cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
 
     // TODO: hacky enc-dec support
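
Usage note (not part of the patch): a minimal sketch of what this change is intended
to enable, assuming the public llama.cpp C API (llama_model_load_from_file,
llama_init_from_model, llama_decode, llama_get_logits, llama_get_embeddings).
The model path and token ids below are placeholders. With embeddings requested and
causal attention left on, both outputs can be read after one decode call, where
previously enabling embeddings forced the logits pointer to nullptr.

    #include "llama.h"
    #include <cstdio>
    #include <vector>

    int main() {
        // load the model once; path is a placeholder
        llama_model * model = llama_model_load_from_file("model.gguf", llama_model_default_params());
        if (!model) return 1;

        llama_context_params cparams = llama_context_default_params();
        cparams.embeddings = true; // request embeddings; causal attention stays enabled

        llama_context * ctx = llama_init_from_model(model, cparams);
        if (!ctx) return 1;

        std::vector<llama_token> tokens = {1, 2, 3}; // placeholder token ids
        llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());

        if (llama_decode(ctx, batch) == 0) {
            // before this patch, logits were nullptr whenever embeddings were enabled;
            // after it, logits are gated on causal attention instead, so both are usable
            float * logits = llama_get_logits(ctx);
            float * embd   = llama_get_embeddings(ctx);
            printf("logits: %p  embeddings: %p\n", (void *) logits, (void *) embd);
        }

        llama_free(ctx);
        llama_model_free(model);
        return 0;
    }

(When a pooling type other than LLAMA_POOLING_TYPE_NONE is configured, pooled
embeddings would be read per sequence via llama_get_embeddings_seq instead.)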