"fs/ggml/gguf.go" did not exist on "a49d6acc1eff460655ee777989a79a765afc1402"
0006-add-mllama-support.patch 44.4 KB
Newer Older
1
2
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: jmorganca <jmorganca@gmail.com>
Date: Sun, 20 Apr 2025 16:12:36 -0700
Subject: [PATCH] add mllama support

adds support for the llama 3.2 vision architecture
---
 ggml/src/ggml-backend-reg.cpp |   6 +-
 include/llama.h               |   6 +
 src/llama-arch.cpp            |  44 +++++
 src/llama-arch.h              |  10 ++
 src/llama-batch.cpp           |   3 +
 src/llama-context.cpp         |  23 ++-
 src/llama-context.h           |   1 +
 src/llama-cparams.h           |   1 +
 src/llama-graph.cpp           |  25 +++
 src/llama-graph.h             |  12 ++
 src/llama-hparams.cpp         |   4 +
 src/llama-hparams.h           |   7 +
 src/llama-kv-cache.cpp        |  14 +-
 src/llama-model-loader.cpp    |   2 +
 src/llama-model.cpp           | 311 +++++++++++++++++++++++++++++++++-
 src/llama-model.h             |  12 ++
 src/llama-quant.cpp           |   4 +-
 tools/mtmd/llava.cpp          |   5 +-
 tools/mtmd/mtmd-helper.cpp    |   7 +-
 19 files changed, 475 insertions(+), 22 deletions(-)
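
Not part of the patch itself: a minimal caller-side sketch of how the new surface area introduced below (llama_set_cross_attention plus the n_embd field now carried on llama_batch) might be driven, assuming the stock llama.cpp batch helpers llama_batch_init/llama_decode/llama_batch_free. The helper name eval_image_embd and its arguments are illustrative only; the position and sequence-id handling mirrors the llava_embd_batch changes further down in this diff.

    // hypothetical helper, not part of the patch
    static int eval_image_embd(struct llama_context * ctx, const float * img_embd,
                               int32_t n_img_tokens, int32_t n_embd, llama_pos pos_0) {
        // enable the cross-attention path introduced by this patch
        llama_set_cross_attention(ctx, true);

        // allocate an embedding batch; the batch now records its own n_embd
        struct llama_batch batch = llama_batch_init(n_img_tokens, n_embd, 1);
        batch.n_tokens = n_img_tokens;
        for (int32_t i = 0; i < n_img_tokens; ++i) {
            for (int32_t k = 0; k < n_embd; ++k) {
                batch.embd[i*n_embd + k] = img_embd[i*n_embd + k];
            }
            batch.pos[i]       = pos_0 + i;   // consecutive positions for the image tokens
            batch.n_seq_id[i]  = 1;
            batch.seq_id[i][0] = 0;
            batch.logits[i]    = 0;           // no logits needed for embedding tokens
        }

        const int ret = llama_decode(ctx, batch);
        llama_batch_free(batch);
        return ret;
    }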

diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
index 405d8e31..82ae1b5b 100644
--- a/ggml/src/ggml-backend-reg.cpp
+++ b/ggml/src/ggml-backend-reg.cpp
@@ -178,9 +178,9 @@ struct ggml_backend_registry {
 #ifdef GGML_USE_CANN
         register_backend(ggml_backend_cann_reg());
 #endif
-#ifdef GGML_USE_BLAS
-        register_backend(ggml_backend_blas_reg());
-#endif
+// #ifdef GGML_USE_BLAS
+//         register_backend(ggml_backend_blas_reg());
+// #endif
 #ifdef GGML_USE_RPC
         register_backend(ggml_backend_rpc_reg());
 #endif
diff --git a/include/llama.h b/include/llama.h
index abedebdb..41beef21 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -258,6 +258,7 @@ extern "C" {
 
         llama_token  *  token;
         float        *  embd;
+        int32_t         n_embd;
         llama_pos    *  pos;
         int32_t      *  n_seq_id;
         llama_seq_id ** seq_id;
@@ -365,6 +366,7 @@ extern "C" {
         bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
         bool no_perf;     // whether to measure performance timings
         bool op_offload;  // whether to offload host tensor operations to device
+        bool cross_attn;  // whether to use cross attention
     };
 
     // model quantization parameters
@@ -464,6 +466,10 @@ extern "C" {
             struct llama_context_params   params),
             "use llama_init_from_model instead");
 
+    // TODO (jmorganca): this should most likely be passed in as part of a batch
+    // and not set on the context for all batches.
+    LLAMA_API void llama_set_cross_attention(struct llama_context * ctx, bool cross_attn_state);
+
     // Frees all allocated memory
     LLAMA_API void llama_free(struct llama_context * ctx);
 
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 5ab3f572..eb7b5325 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -6,6 +6,7 @@
 
 static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_LLAMA,            "llama"            },
+    { LLM_ARCH_MLLAMA,           "mllama"           },
     { LLM_ARCH_LLAMA4,           "llama4"           },
     { LLM_ARCH_DECI,             "deci"             },
     { LLM_ARCH_FALCON,           "falcon"           },
@@ -144,6 +145,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_SLIDING_WINDOW,               "%s.attention.sliding_window"               },
     { LLM_KV_ATTENTION_SCALE,                        "%s.attention.scale"                        },
     { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,        "%s.attention.block_skip_connection"        },
+    { LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS,       "%s.attention.cross_attention_layers"       },
     { LLM_KV_ATTENTION_KEY_LENGTH_MLA,               "%s.attention.key_length_mla"               },
     { LLM_KV_ATTENTION_VALUE_LENGTH_MLA,             "%s.attention.value_length_mla"             },
 
@@ -273,6 +275,40 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_SHEXP,    "blk.%d.ffn_up_shexp" },
         },
     },
+    {
+        LLM_ARCH_MLLAMA,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXP,    "blk.%d.ffn_gate.%d" },
+            { LLM_TENSOR_FFN_DOWN_EXP,    "blk.%d.ffn_down.%d" },
+            { LLM_TENSOR_FFN_UP_EXP,      "blk.%d.ffn_up.%d" },
+            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_CROSS_ATTN_K_NORM,    "blk.%d.cross_attn_k_norm" },
+            { LLM_TENSOR_CROSS_ATTN_K_PROJ,    "blk.%d.cross_attn_k_proj" },
+            { LLM_TENSOR_CROSS_ATTN_O_PROJ,    "blk.%d.cross_attn_o_proj" },
+            { LLM_TENSOR_CROSS_ATTN_Q_NORM,    "blk.%d.cross_attn_q_norm" },
+            { LLM_TENSOR_CROSS_ATTN_Q_PROJ,    "blk.%d.cross_attn_q_proj" },
+            { LLM_TENSOR_CROSS_ATTN_V_PROJ,    "blk.%d.cross_attn_v_proj" },
+            { LLM_TENSOR_CROSS_ATTN_ATTN_GATE, "blk.%d.cross_attn_attn_gate" },
+            { LLM_TENSOR_CROSS_ATTN_MLP_GATE,  "blk.%d.cross_attn_mlp_gate" },
+        },
+    },
     {
         LLM_ARCH_DECI,
         {
@@ -1701,6 +1737,14 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     // this tensor is loaded for T5, but never used
     {LLM_TENSOR_DEC_CROSS_ATTN_REL_B,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
     {LLM_TENSOR_BSKCN_TV,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_CROSS_ATTN_K_NORM,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_CROSS_ATTN_K_PROJ,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_CROSS_ATTN_O_PROJ,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_CROSS_ATTN_Q_NORM,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_CROSS_ATTN_Q_PROJ,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_CROSS_ATTN_V_PROJ,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_CROSS_ATTN_ATTN_GATE,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_CROSS_ATTN_MLP_GATE,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_CONV1D,                     {LLM_TENSOR_LAYER_INPUT,     GGML_OP_IM2COL}},
     {LLM_TENSOR_POS_NET_NORM,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_POS_NET_NORM1,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 525c1b7d..bc8a4f0b 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -11,6 +11,7 @@
 enum llm_arch {
     LLM_ARCH_LLAMA,
     LLM_ARCH_LLAMA4,
+    LLM_ARCH_MLLAMA,
     LLM_ARCH_DECI,
     LLM_ARCH_FALCON,
     LLM_ARCH_BAICHUAN,
@@ -148,6 +149,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_SLIDING_WINDOW,
     LLM_KV_ATTENTION_SCALE,
     LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,
+    LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS,
     LLM_KV_ATTENTION_KEY_LENGTH_MLA,
     LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
 
@@ -349,6 +351,14 @@ enum llm_tensor {
     LLM_TENSOR_CLS,
     LLM_TENSOR_CLS_OUT,
     LLM_TENSOR_BSKCN_TV,
+    LLM_TENSOR_CROSS_ATTN_K_NORM,
+    LLM_TENSOR_CROSS_ATTN_K_PROJ,
+    LLM_TENSOR_CROSS_ATTN_O_PROJ,
+    LLM_TENSOR_CROSS_ATTN_Q_NORM,
+    LLM_TENSOR_CROSS_ATTN_Q_PROJ,
+    LLM_TENSOR_CROSS_ATTN_V_PROJ,
+    LLM_TENSOR_CROSS_ATTN_ATTN_GATE,
+    LLM_TENSOR_CROSS_ATTN_MLP_GATE,
     LLM_TENSOR_CONV1D,
     LLM_TENSOR_CONVNEXT_DW,
     LLM_TENSOR_CONVNEXT_NORM,
diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index a88b2fe3..241b316e 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -320,6 +320,7 @@ struct llama_batch llama_batch_get_one(
         /*n_tokens       =*/ n_tokens,
         /*tokens         =*/ tokens,
         /*embd           =*/ nullptr,
+        /*n_embd         =*/ 0,
         /*pos            =*/ nullptr,
         /*n_seq_id       =*/ nullptr,
         /*seq_id         =*/ nullptr,
@@ -332,6 +333,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
         /*n_tokens       =*/ 0,
         /*tokens         =*/ nullptr,
         /*embd           =*/ nullptr,
+        /*n_embd         =*/ 0,
         /*pos            =*/ nullptr,
         /*n_seq_id       =*/ nullptr,
         /*seq_id         =*/ nullptr,
@@ -340,6 +342,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
 
     if (embd) {
         batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
+        batch.n_embd = embd;
     } else {
         batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
     }
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index dca22d8b..c22687e4 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -514,7 +514,7 @@ float * llama_context::get_logits_ith(int32_t i) {
             throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
         }
 
-        return logits + j*model.vocab.n_tokens();
+        return logits + j*model.hparams.n_vocab;
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
@@ -632,6 +632,10 @@ void llama_context::set_warmup(bool value) {
     cparams.warmup = value;
 }
 
+void llama_context::set_cross_attn(bool value) {
+    cparams.cross_attn = value;
+}
+
 void llama_context::set_adapter_lora(
             llama_adapter_lora * adapter,
             float scale) {
@@ -709,7 +713,7 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     const int64_t n_embd = hparams.n_embd;
 
-    llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
+    llama_sbatch sbatch = llama_sbatch(batch, batch.n_embd, /* simple_split */ true, /* logits_all */ true);
 
     const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
 
@@ -863,10 +867,9 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     const llama_batch & batch = batch_allocr.batch;
 
-    const auto & vocab   = model.vocab;
     const auto & hparams = model.hparams;
 
-    const int32_t n_vocab = vocab.n_tokens();
+    const int32_t n_vocab = hparams.n_vocab;
 
     const int64_t n_tokens_all = batch.n_tokens;
     const int64_t n_embd       = hparams.n_embd;
@@ -1087,7 +1090,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         // make the outputs have the same order they had in the user-provided batch
         // note: this is mostly relevant for recurrent models atm
         if (!sorted_output) {
-            const uint32_t n_vocab = model.vocab.n_tokens();
+            const uint32_t n_vocab = model.hparams.n_vocab;
             const uint32_t n_embd  = model.hparams.n_embd;
 
             GGML_ASSERT((size_t) n_outputs == out_ids.size());
@@ -1142,12 +1145,11 @@ int llama_context::decode(llama_batch & inp_batch) {
 
 int32_t llama_context::output_reserve(int32_t n_outputs) {
     const auto & hparams = model.hparams;
-    const auto & vocab   = model.vocab;
 
     const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max());
 
     const auto n_batch = cparams.n_batch;
-    const auto n_vocab = vocab.n_tokens();
+    const auto n_vocab = hparams.n_vocab;
     const auto n_embd  = hparams.n_embd;
 
     // TODO: use a per-batch flag for logits presence instead
@@ -1682,7 +1684,7 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
     {
         LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);
 
-        const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens());
+        const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.hparams.n_vocab);
 
         io.write(&logits_size, sizeof(logits_size));
 
@@ -2091,6 +2093,7 @@ llama_context_params llama_context_default_params() {
         /*.flash_attn                  =*/ false,
         /*.no_perf                     =*/ true,
         /*.op_offload                  =*/ true,
+        /*.cross_attn                  =*/ false,
     };
 
     return result;
@@ -2216,6 +2219,10 @@ void llama_set_warmup(llama_context * ctx, bool warmup) {
     ctx->set_warmup(warmup);
 }
 
+void llama_set_cross_attention(struct llama_context * ctx, bool cross_attention) {
+    ctx->set_cross_attn(cross_attention);
+}
+
 void llama_synchronize(llama_context * ctx) {
     ctx->synchronize();
 }
diff --git a/src/llama-context.h b/src/llama-context.h
index c0ceacb1..c4ab242a 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -71,6 +71,7 @@ struct llama_context {
     void set_embeddings (bool value);
     void set_causal_attn(bool value);
     void set_warmup(bool value);
+    void set_cross_attn(bool value);
 
     void set_adapter_lora(
             llama_adapter_lora * adapter,
diff --git a/src/llama-cparams.h b/src/llama-cparams.h
index 246fa577..7a6156ce 100644
--- a/src/llama-cparams.h
+++ b/src/llama-cparams.h
@@ -31,6 +31,7 @@ struct llama_cparams {
     bool no_perf;
     bool warmup;
     bool op_offload;
+    bool cross_attn;
 
     enum llama_pooling_type pooling_type;
 
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index b0e3f635..f14869cf 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -532,6 +532,12 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+void llm_graph_input_cross_attn_state::set_input(const llama_ubatch * ubatch) {
+    if (ubatch->embd) {
+        ggml_backend_tensor_set(cross_attn_state, ubatch->embd, 0, ggml_nbytes(cross_attn_state));
+    }
+}
+
 //
 // llm_graph_context
 //
@@ -1514,6 +1520,25 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
     return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
 }
 
+ggml_tensor * llm_graph_context::build_inp_cross_attn_state() const {
+    const int64_t n_embd = hparams.n_embd;
+
+    auto inp = std::make_unique<llm_graph_input_cross_attn_state>();
+
+    ggml_tensor * cur = nullptr;
+
+    inp->cross_attn_state = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd, 1601, 4);
+    ggml_set_input(inp->cross_attn_state);
+
+    cur = inp->cross_attn_state;
+
+    cb(cur, "inp_cross_attn_state", -1);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_cross * inp,
         ggml_cgraph * gf,
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 832a8c09..5a322785 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -87,6 +87,7 @@ public:
 
     ggml_tensor * tokens = nullptr; // I32 [n_batch]
     ggml_tensor * embd   = nullptr; // F32 [n_embd, n_batch]
+    ggml_tensor * cross_attn_state; // F32 [4, n_embd, 1061]
 };
 
 class llm_graph_input_pos : public llm_graph_input_i {
@@ -284,6 +285,16 @@ public:
     const llama_cross * cross = nullptr;
 };
 
+class llm_graph_input_cross_attn_state : public llm_graph_input_i {
+public:
+    llm_graph_input_cross_attn_state()          = default;
+    virtual ~llm_graph_input_cross_attn_state() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * cross_attn_state; // F32 [4, n_embd, 1061]
+};
+
 //
 // llm_graph_result
 //
@@ -495,6 +506,7 @@ struct llm_graph_context {
     ggml_tensor * build_inp_cls() const;
     ggml_tensor * build_inp_s_copy() const;
     ggml_tensor * build_inp_s_mask() const;
+    ggml_tensor * build_inp_cross_attn_state() const;
 
     ggml_tensor * build_inp_cross_embd() const;
     ggml_tensor * build_inp_pos_bucket_enc() const;
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index 8a667960..6a02de03 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -85,3 +85,7 @@ bool llama_hparams::is_swa(uint32_t il) const {
 
     GGML_ABORT("fatal error");
 }
+
+bool llama_hparams::cross_attention_layers(uint32_t il) const {
+    return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end();
+}
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index 48dce407..b6fc7e6d 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -2,6 +2,8 @@
 
 #include "llama.h"
 
+#include <algorithm>
+
 #include <array>
 
 // bump if necessary
@@ -42,6 +44,7 @@ struct llama_hparams {
     uint32_t n_expert = 0;
     uint32_t n_expert_used = 0;
     uint32_t n_rel_attn_bkts = 0;
+    uint32_t n_vocab = 0;
 
     // note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
     uint32_t n_embd_head_k_mla = 0;
@@ -56,6 +59,7 @@ struct llama_hparams {
     std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
 
     std::array<std::array<uint32_t, LLAMA_MAX_LAYERS>, 4> n_bskcn_arr = {};
+    std::array<uint32_t, LLAMA_MAX_LAYERS> cross_attn_layers;
 
     uint32_t n_layer_dense_lead = 0;
     uint32_t n_lora_q           = 0;
@@ -159,6 +163,9 @@ struct llama_hparams {
     // Block skip connection
     bool n_bskcn(uint32_t n, uint32_t il) const;
 
+    // cross attention layers
+    bool cross_attention_layers(uint32_t il) const;
+
     bool is_swa(uint32_t il) const;
 };
 
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index 3dcad65b..a7b0a7eb 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -100,8 +100,16 @@ llama_kv_cache_unified::llama_kv_cache_unified(
             throw std::runtime_error("failed to create ggml context for kv cache");
         }
 
-        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
-        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
+        ggml_tensor * k, *v;
+
+        // for cross attention layers
+        if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layers(i)) {
+            k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i));
+            v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i));
+        } else {
+            k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
+            v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
+        }
         ggml_format_name(k, "cache_k_l%d", i);
         ggml_format_name(v, "cache_v_l%d", i);
         k_l.push_back(k);
@@ -446,7 +454,7 @@ void llama_kv_cache_unified::set_full() {
 llama_sbatch llama_kv_cache_unified::sbatch_init(
         const llama_batch & batch,
         bool logits_all) {
-    return llama_sbatch(batch, hparams.n_embd, true, logits_all);
+    return llama_sbatch(batch, batch.n_embd, true, logits_all);
 }
 
 llama_ubatch llama_kv_cache_unified::ubatch_next(
diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
index 7f6617fa..2acfd4a8 100644
--- a/src/llama-model-loader.cpp
+++ b/src/llama-model-loader.cpp
@@ -315,6 +315,8 @@ namespace GGUFMeta {
         return true;
     }
 
+    template bool llama_model_loader::get_arr<std::array<unsigned int, 512>>(enum llm_kv kid, std::array<unsigned int, 512>& result, bool required);
+
     template<typename T, size_t N_MAX>
     bool llama_model_loader::get_arr(const std::string & key, std::array<T, N_MAX> & result, bool required) {
         const int kid = gguf_find_key(meta.get(), key.c_str());
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 831b68c0..e8298f56 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -433,6 +433,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
 
     // get general kv
     ml.get_key(LLM_KV_GENERAL_NAME, name, false);
+    ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab, false);
 
     // everything past this point is not vocab-related
     if (hparams.vocab_only) {
@@ -444,6 +445,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     ml.get_key(LLM_KV_BLOCK_COUNT,       hparams.n_layer);
     ml.get_key(LLM_KV_EXPERT_COUNT,      hparams.n_expert,      false);
     ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
+    ml.get_key(LLM_KV_VOCAB_SIZE,        hparams.n_vocab,       false);
 
     if (arch == LLM_ARCH_WAVTOKENIZER_DEC) {
         ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features);
@@ -467,9 +469,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     std::fill(hparams.n_head_arr.begin(),    hparams.n_head_arr.end(),    0);
     std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
     std::fill(hparams.n_ff_arr.begin(),      hparams.n_ff_arr.end(),      0);
+    std::fill(hparams.cross_attn_layers.begin(), hparams.cross_attn_layers.end(), -1);
 
     ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH,  hparams.n_ff_arr,   hparams.n_layer, false);
     ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false);
+    ml.get_arr(LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, hparams.cross_attn_layers, false);
 
     // n_head_kv is optional, default to n_head
     hparams.n_head_kv_arr = hparams.n_head_arr;
@@ -522,7 +526,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
 
         ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
 
-        if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON) {
+        if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_MLLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON) {
             if (hparams.n_rot != hparams.n_embd_head_k) {
                 throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
             }
@@ -585,6 +589,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     hparams.use_kq_norm = false;
                 }
             } break;
+        case LLM_ARCH_MLLAMA:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                switch (hparams.n_layer) {
+                    case 40: type = LLM_TYPE_11B; break;
+                    case 100: type = LLM_TYPE_90B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_DECI:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -1581,7 +1595,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
         const int64_t n_embd_head_v = hparams.n_embd_head_v;
         const int64_t n_ff          = hparams.n_ff();
         const int64_t n_embd_gqa    = n_embd_v_gqa;
-        const int64_t n_vocab       = vocab.n_tokens();
+        const int64_t n_vocab       = hparams.n_vocab;
         const int64_t n_token_types = vocab.n_token_types();
         const int64_t n_rot         = hparams.n_rot;
         const int64_t n_expert      = hparams.n_expert;
@@ -1840,6 +1854,52 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         }
                     }
                 } break;
+            case LLM_ARCH_MLLAMA:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab+8}, 0);
+
+                    // output
+                    {
+                        output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                        output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                        // if output is NULL, init from the input tok embed
+                        if (output == NULL) {
+                            output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
+                        }
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        if (hparams.cross_attention_layers(i)) {
+                            layer.cross_attn_k_norm = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_K_NORM,   "weight", i), {128}, 0);
+                            layer.cross_attn_k_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_K_PROJ,   "weight", i), {n_embd, 1024}, 0);
+                            layer.cross_attn_o_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_O_PROJ,   "weight", i), {n_embd, n_embd}, 0);
+                            layer.cross_attn_q_norm = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_Q_NORM, "weight", i), {128}, 0);
+                            layer.cross_attn_q_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_Q_PROJ, "weight", i), {n_embd, n_embd}, 0);
+                            layer.cross_attn_v_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_V_PROJ, "weight", i), {n_embd, 1024}, 0);
+                            layer.cross_attn_attn_gate = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_ATTN_GATE, i), {1}, 0);
+                            layer.cross_attn_mlp_gate = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_MLP_GATE, i), {1}, 0);
+                            layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                            layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        } else {
+                            layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                            layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+                            layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                            layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        }
+                    }
+                } break;
             case LLM_ARCH_DECI:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -4756,6 +4816,246 @@ struct llm_build_llama : public llm_graph_context {
     }
 };
 
+struct llm_build_mllama: public llm_graph_context {
+    llm_build_mllama(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        // mutable variable, needed during the last layer of the computation to skip unused tokens
+        int32_t n_tokens = this->n_tokens;
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+        ggml_tensor * inpCAS;
+
+        inpL = build_inp_embd(model.tok_embd);
+        inpCAS = build_inp_cross_attn_state();
+
+          // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified();
+        const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            if (hparams.cross_attention_layers(il)) {
+                if (!ubatch.embd && !cparams.cross_attn) {
+                    continue;
+                }
+
+                // cross attention layer
+                ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_q_proj, cur);
+                cb(Qcur, "Qcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                cb(Qcur, "Qcur", il);
+
+                Qcur = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3));
+                cb(Qcur, "Qcur", il);
+
+                Qcur = build_norm(Qcur, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur, * Vcur;
+                if (ubatch.embd) {
+                    Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_k_proj, inpCAS);
+                    cb(Kcur, "Kcur", il);
+
+                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, 6404);
+                    cb(Kcur, "Kcur", il);
+
+                    Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
+                    cb(Kcur, "Kcur", il);
+
+                    Kcur = build_norm(Kcur, model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, il);
+                    cb(Kcur, "Kcur", il);
+
+                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self->k_l[il]));
+
+                    Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_v_proj, inpCAS);
+                    cb(Vcur, "Vcur", il);
+
+                    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, 6404);
+                    cb(Vcur, "Vcur", il);
+
+                    Vcur = ggml_permute(ctx0, Vcur, 0, 2, 1, 3);
+                    cb(Vcur, "Vcur", il);
+
+                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self->v_l[il]));
+                } else {
+                    Kcur = ggml_view_tensor(ctx0, kv_self->k_l[il]);
+                    cb(Kcur, "Kcur (view)", il);
+
+                    Vcur = ggml_view_tensor(ctx0, kv_self->v_l[il]);
+                    cb(Vcur, "Vcur (view)", il);
+                }
+
+                struct ggml_tensor * kq = ggml_mul_mat(ctx0, Kcur, Qcur);
+                cb(kq, "kq", il);
+
+                // TODO: apply causal masks
+                struct ggml_tensor * kq_soft_max = ggml_soft_max_ext(ctx0, kq, nullptr, 1.f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
+                cb(kq_soft_max, "kq_soft_max", il);
+
+                Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, Vcur));
+                cb(Vcur, "Vcur", il);
+
+                struct ggml_tensor * kqv = ggml_mul_mat(ctx0, Vcur, kq_soft_max);
+                cb(kqv, "kqv", il);
+
+                struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+                cb(kqv_merged, "kqv_merged", il);
+
+                cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
+                cb(cur, "kqv_merged_cont", il);
+
+                cur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_o_proj, cur);
+                cb(cur, "cur", il);
+
+                // TODO: do this in place once?
+                cur = ggml_mul(ctx0, cur, ggml_tanh(ctx0, model.layers[il].cross_attn_attn_gate));
+
+                struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+                cb(ffn_inp, "ffn_inp", il);
+
+                // feed-forward network
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+
+                // TODO: do this inplace once?
+                cur = ggml_add_inplace(ctx0, ggml_mul_inplace(ctx0, cur, ggml_tanh(ctx0, model.layers[il].cross_attn_mlp_gate)), ffn_inp);
+                cb(cur, "ffn_out", il);
+
+                cur = build_cvec(cur, il);
+                cb(cur, "l_out", il);
+
+                // input for next layer
+                inpL = cur;
+            } else {
+                // self attention layer
+
+                // rope freq factors for llama3; may return nullptr for llama2 and other models
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
+
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                    model.layers[il].wo, model.layers[il].bo,
820
+                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
821
822
823
824
825
826
827
828
829
830
831
832
833
+
+                if (il == n_layer - 1) {
+                    // skip computing output for unused tokens
+                    struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                    n_tokens = n_outputs;
+                    cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                    inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+                }
+
+                struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+                cb(ffn_inp, "ffn_inp", il);
+
+                // feed-forward network
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+
+                cur = ggml_add(ctx0, cur, ffn_inp);
+                cb(cur, "ffn_out", il);
+
+                cur = build_cvec(cur, il);
+                cb(cur, "l_out", il);
+
+                // input for next layer
+                inpL = cur;
+            }
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 struct llm_build_deci : public llm_graph_context {
     llm_build_deci(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -12496,7 +12796,7 @@ struct llm_build_solar : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -13128,6 +13428,10 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_llama>(*this, params, gf);
             } break;
+        case LLM_ARCH_MLLAMA:
+            {
+                llm = std::make_unique<llm_build_mllama>(*this, params, gf);
+            } break;
         case LLM_ARCH_DECI:
             {
                 llm = std::make_unique<llm_build_deci>(*this, params, gf);
@@ -13489,6 +13793,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         // use what we call a normal RoPE, operating on pairs of consecutive head values
         case LLM_ARCH_LLAMA:
         case LLM_ARCH_LLAMA4:
+        case LLM_ARCH_MLLAMA:
         case LLM_ARCH_DECI:
         case LLM_ARCH_BAICHUAN:
         case LLM_ARCH_STARCODER:
diff --git a/src/llama-model.h b/src/llama-model.h
index 43746c7d..9281e629 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -11,6 +11,7 @@
 #include <string>
 #include <unordered_map>
 #include <vector>
+#include <stdexcept>
 
 struct llama_cparams;
 struct llama_ubatch;
@@ -74,6 +75,7 @@ enum llm_type {
     LLM_TYPE_40B,
     LLM_TYPE_65B,
     LLM_TYPE_70B,
+    LLM_TYPE_90B,
     LLM_TYPE_236B,
     LLM_TYPE_290B,
     LLM_TYPE_314B,
@@ -318,6 +320,16 @@ struct llama_layer {
 
     struct ggml_tensor * bskcn_tv = nullptr;
 
+    // cross attention
+    struct ggml_tensor * cross_attn_k_norm = nullptr;
+    struct ggml_tensor * cross_attn_k_proj = nullptr;
+    struct ggml_tensor * cross_attn_o_proj = nullptr;
+    struct ggml_tensor * cross_attn_q_norm = nullptr;
+    struct ggml_tensor * cross_attn_q_proj = nullptr;
+    struct ggml_tensor * cross_attn_v_proj = nullptr;
+    struct ggml_tensor * cross_attn_attn_gate = nullptr;
+    struct ggml_tensor * cross_attn_mlp_gate = nullptr;
+
     struct llama_layer_posnet posnet;
 
     struct llama_layer_convnext convnext;
diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
index 820d5128..56531980 100644
--- a/src/llama-quant.cpp
+++ b/src/llama-quant.cpp
@@ -639,7 +639,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         if (llama_model_has_encoder(&model)) {
             n_attn_layer *= 3;
         }
-        GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
+        if (qs.n_attention_wv != n_attn_layer) {
+            LLAMA_LOG_WARN("%s: n_attention_wv is unexpected, expected: %d, found: %d\n", __func__, n_attn_layer, qs.n_attention_wv);
+        }
     }
 
     size_t total_size_org = 0;
diff --git a/tools/mtmd/llava.cpp b/tools/mtmd/llava.cpp
index ebef8b3c..b0eb79bb 100644
--- a/tools/mtmd/llava.cpp
+++ b/tools/mtmd/llava.cpp
@@ -462,7 +462,7 @@ struct llava_embd_batch {
     std::vector<llama_seq_id *> seq_ids;
     std::vector<int8_t>         logits;
     llama_batch batch;
-    llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
+    llava_embd_batch(float * embd, int32_t n_embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
         pos     .resize(n_tokens);
         n_seq_id.resize(n_tokens);
         seq_ids .resize(n_tokens + 1);
@@ -474,6 +474,7 @@ struct llava_embd_batch {
             /*n_tokens       =*/ n_tokens,
             /*tokens         =*/ nullptr,
             /*embd           =*/ embd,
+            /*n_embd         =*/ n_embd,
             /*pos            =*/ pos.data(),
             /*n_seq_id       =*/ n_seq_id.data(),
             /*seq_id         =*/ seq_ids.data(),
@@ -497,7 +498,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
             n_eval = n_batch;
         }
         float * embd = image_embed->embed+i*n_embd;
-        llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0);
+        llava_embd_batch llava_batch = llava_embd_batch(embd, n_embd, n_eval, *n_past, 0);
         if (llama_decode(ctx_llama, llava_batch.batch)) {
             LOG_ERR("%s : failed to eval\n", __func__);
             return false;
diff --git a/tools/mtmd/mtmd-helper.cpp b/tools/mtmd/mtmd-helper.cpp
index 7a328867..61ebdd43 100644
--- a/tools/mtmd/mtmd-helper.cpp
+++ b/tools/mtmd/mtmd-helper.cpp
@@ -58,7 +58,7 @@ struct decode_embd_batch {
     std::vector<llama_seq_id *> seq_ids;
     std::vector<int8_t>         logits;
     llama_batch batch;
-    decode_embd_batch(float * embd, int32_t n_tokens, int n_pos_per_embd, int n_mmproj_embd) : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) {
+    decode_embd_batch(float * embd, int32_t n_embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) {
         pos     .resize(n_tokens * n_pos_per_embd);
         n_seq_id.resize(n_tokens);
         seq_ids .resize(n_tokens + 1);
@@ -69,6 +69,7 @@ struct decode_embd_batch {
             /*n_tokens       =*/ n_tokens,
             /*tokens         =*/ nullptr,
             /*embd           =*/ embd,
+            /*n_embd         =*/ n_embd,
             /*pos            =*/ pos.data(),
             /*n_seq_id       =*/ n_seq_id.data(),
             /*seq_id         =*/ seq_ids.data(),
@@ -131,6 +132,7 @@ struct decode_embd_batch {
             /*n_tokens       =*/ n_tokens,
             /*tokens         =*/ nullptr,
             /*embd           =*/ batch.embd     + offset * n_mmproj_embd,
+            /*n_embd         =*/ batch.n_embd,
             /*pos            =*/ pos_ptr,
             /*n_seq_id       =*/ batch.n_seq_id + offset,
             /*seq_id         =*/ batch.seq_id   + offset,
@@ -166,7 +168,8 @@ int32_t mtmd_helper_decode_image_chunk(
     int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
     int32_t i_batch = 0;
     int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
-    decode_embd_batch batch_embd(encoded_embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
+    int n_embd  = llama_model_n_embd(llama_get_model(lctx));
+    decode_embd_batch batch_embd(encoded_embd, n_embd, n_tokens, n_past, seq_id);
 
     const int nx = mtmd_image_tokens_get_nx(image_tokens);
     const int ny = mtmd_image_tokens_get_ny(image_tokens);