0006-add-mllama-support.patch
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: jmorganca <jmorganca@gmail.com>
Date: Sun, 20 Apr 2025 16:12:36 -0700
Subject: [PATCH] add mllama support

Adds support for the Llama 3.2 Vision (mllama) architecture.
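
For reference, a rough sketch of how a caller is expected to drive the new
API once this patch is applied (illustrative only: n_img_tokens and n_embd
stand in for the vision encoder's output shape, and batch setup is
abbreviated):

    // the patched llama_batch_init() records the embedding width in the
    // new batch.n_embd field when an embedding batch is allocated
    llama_batch batch = llama_batch_init(n_img_tokens, n_embd, 1);

    // ... fill batch.embd with the projected image embeddings and set
    // batch.n_tokens, batch.pos and batch.seq_id as usual ...

    // new in this patch: enable cross attention on the context before
    // decoding the image embeddings
    llama_set_cross_attention(ctx, true);

    if (llama_decode(ctx, batch) != 0) {
        // handle decode failure
    }
    llama_batch_free(batch);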
---
 examples/llava/gemma3-cli.cpp |   3 +-
 examples/llava/llava.cpp      |   5 +-
 examples/llava/mtmd.cpp       |   6 +-
 ggml/src/ggml-backend-reg.cpp |   6 +-
 include/llama.h               |   6 +
 src/llama-arch.cpp            |  44 +++++
 src/llama-arch.h              |  10 ++
 src/llama-batch.cpp           |   3 +
 src/llama-context.cpp         |  25 ++-
 src/llama-context.h           |   1 +
 src/llama-cparams.h           |   1 +
 src/llama-graph.cpp           |  25 +++
 src/llama-graph.h             |  12 ++
 src/llama-hparams.cpp         |   4 +
 src/llama-hparams.h           |   7 +
 src/llama-kv-cache.cpp        |  12 +-
 src/llama-model-loader.cpp    |   2 +
 src/llama-model.cpp           | 309 +++++++++++++++++++++++++++++++++-
 src/llama-model.h             |  12 ++
 src/llama-quant.cpp           |   4 +-
 20 files changed, 475 insertions(+), 22 deletions(-)

diff --git a/examples/llava/gemma3-cli.cpp b/examples/llava/gemma3-cli.cpp
index 3d566475..654d1358 100644
--- a/examples/llava/gemma3-cli.cpp
+++ b/examples/llava/gemma3-cli.cpp
@@ -106,7 +106,7 @@ struct decode_embd_batch {
     std::vector<llama_seq_id *> seq_ids;
     std::vector<int8_t>         logits;
     llama_batch batch;
-    decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
+    decode_embd_batch(float * embd, int32_t n_embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
         pos     .resize(n_tokens);
         n_seq_id.resize(n_tokens);
         seq_ids .resize(n_tokens + 1);
@@ -118,6 +118,7 @@ struct decode_embd_batch {
             /*n_tokens       =*/ n_tokens,
             /*tokens         =*/ nullptr,
             /*embd           =*/ embd,
+            /*n_embd         =*/ n_embd,
             /*pos            =*/ pos.data(),
             /*n_seq_id       =*/ n_seq_id.data(),
             /*seq_id         =*/ seq_ids.data(),
diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp
index 03a22cbb..5eb40bcd 100644
--- a/examples/llava/llava.cpp
+++ b/examples/llava/llava.cpp
@@ -456,7 +456,7 @@ struct llava_embd_batch {
     std::vector<llama_seq_id *> seq_ids;
     std::vector<int8_t>         logits;
     llama_batch batch;
-    llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
+    llava_embd_batch(float * embd, int32_t n_embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
         pos     .resize(n_tokens);
         n_seq_id.resize(n_tokens);
         seq_ids .resize(n_tokens + 1);
@@ -468,6 +468,7 @@ struct llava_embd_batch {
             /*n_tokens       =*/ n_tokens,
             /*tokens         =*/ nullptr,
             /*embd           =*/ embd,
+            /*n_embd         =*/ n_embd,
             /*pos            =*/ pos.data(),
             /*n_seq_id       =*/ n_seq_id.data(),
             /*seq_id         =*/ seq_ids.data(),
@@ -491,7 +492,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
             n_eval = n_batch;
         }
         float * embd = image_embed->embed+i*n_embd;
-        llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0);
+        llava_embd_batch llava_batch = llava_embd_batch(embd, n_embd, n_eval, *n_past, 0);
         if (llama_decode(ctx_llama, llava_batch.batch)) {
             LOG_ERR("%s : failed to eval\n", __func__);
             return false;
diff --git a/examples/llava/mtmd.cpp b/examples/llava/mtmd.cpp
index 3fd5bebc..f0cec596 100644
--- a/examples/llava/mtmd.cpp
+++ b/examples/llava/mtmd.cpp
@@ -233,7 +233,7 @@ struct decode_embd_batch {
     std::vector<llama_seq_id *> seq_ids;
     std::vector<int8_t>         logits;
     llama_batch batch;
-    decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
+    decode_embd_batch(float * embd, int32_t n_embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
         pos     .resize(n_tokens);
         n_seq_id.resize(n_tokens);
         seq_ids .resize(n_tokens + 1);
@@ -245,6 +245,7 @@ struct decode_embd_batch {
             /*n_tokens       =*/ n_tokens,
             /*tokens         =*/ nullptr,
             /*embd           =*/ embd,
+            /*n_embd         =*/ n_embd,
             /*pos            =*/ pos.data(),
             /*n_seq_id       =*/ n_seq_id.data(),
             /*seq_id         =*/ seq_ids.data(),
@@ -311,7 +312,8 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
 
             int32_t n_tokens = mtmd_image_tokens_get_n_tokens(chunk.tokens_image.get());
             float * embd = mtmd_get_output_embd(ctx);
-            decode_embd_batch batch_img(embd, n_tokens, n_past, 0);
+            int n_embd  = llama_model_n_embd(llama_get_model(lctx));
+            decode_embd_batch batch_img(embd, n_embd, n_tokens, n_past, 0);
             int64_t t1 = ggml_time_ms();
             ret = llama_decode(lctx, batch_img.batch);
             if (ret != 0) {
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
index 405d8e31..82ae1b5b 100644
--- a/ggml/src/ggml-backend-reg.cpp
+++ b/ggml/src/ggml-backend-reg.cpp
@@ -178,9 +178,9 @@ struct ggml_backend_registry {
 #ifdef GGML_USE_CANN
         register_backend(ggml_backend_cann_reg());
 #endif
-#ifdef GGML_USE_BLAS
-        register_backend(ggml_backend_blas_reg());
-#endif
+// #ifdef GGML_USE_BLAS
+//         register_backend(ggml_backend_blas_reg());
+// #endif
 #ifdef GGML_USE_RPC
         register_backend(ggml_backend_rpc_reg());
 #endif
diff --git a/include/llama.h b/include/llama.h
index 5657fbf0..f91896e4 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -255,6 +255,7 @@ extern "C" {
 
         llama_token  *  token;
         float        *  embd;
+        int32_t         n_embd;
         llama_pos    *  pos;
         int32_t      *  n_seq_id;
         llama_seq_id ** seq_id;
@@ -357,6 +358,7 @@ extern "C" {
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
         bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
         bool no_perf;     // whether to measure performance timings
+        bool cross_attn;  // whether to use cross attention
 
         // Abort callback
         // if it returns true, execution of llama_decode() will be aborted
@@ -458,6 +460,10 @@ extern "C" {
             struct llama_context_params   params),
             "use llama_init_from_model instead");
 
+    // TODO (jmorganca): this should most likely be passed in as part of a batch
+    // and not set on the context for all batches.
+    LLAMA_API void llama_set_cross_attention(struct llama_context * ctx, bool cross_attn_state);
+
     // Frees all allocated memory
     LLAMA_API void llama_free(struct llama_context * ctx);
 
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index f754bc8f..0568565f 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -6,6 +6,7 @@
 
 static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_LLAMA,            "llama"            },
+    { LLM_ARCH_MLLAMA,           "mllama"           },
     { LLM_ARCH_LLAMA4,           "llama4"           },
     { LLM_ARCH_DECI,             "deci"             },
     { LLM_ARCH_FALCON,           "falcon"           },
@@ -142,6 +143,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_SLIDING_WINDOW,               "%s.attention.sliding_window"               },
     { LLM_KV_ATTENTION_SCALE,                        "%s.attention.scale"                        },
     { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,        "%s.attention.block_skip_connection"        },
+    { LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS,       "%s.attention.cross_attention_layers"       },
     { LLM_KV_ATTENTION_KEY_LENGTH_MLA,               "%s.attention.key_length_mla"               },
     { LLM_KV_ATTENTION_VALUE_LENGTH_MLA,             "%s.attention.value_length_mla"             },
 
@@ -271,6 +273,40 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_SHEXP,    "blk.%d.ffn_up_shexp" },
         },
     },
+    {
+        LLM_ARCH_MLLAMA,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXP,    "blk.%d.ffn_gate.%d" },
+            { LLM_TENSOR_FFN_DOWN_EXP,    "blk.%d.ffn_down.%d" },
+            { LLM_TENSOR_FFN_UP_EXP,      "blk.%d.ffn_up.%d" },
+            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_CROSS_ATTN_K_NORM,    "blk.%d.cross_attn_k_norm" },
+            { LLM_TENSOR_CROSS_ATTN_K_PROJ,    "blk.%d.cross_attn_k_proj" },
+            { LLM_TENSOR_CROSS_ATTN_O_PROJ,    "blk.%d.cross_attn_o_proj" },
+            { LLM_TENSOR_CROSS_ATTN_Q_NORM,    "blk.%d.cross_attn_q_norm" },
+            { LLM_TENSOR_CROSS_ATTN_Q_PROJ,    "blk.%d.cross_attn_q_proj" },
+            { LLM_TENSOR_CROSS_ATTN_V_PROJ,    "blk.%d.cross_attn_v_proj" },
+            { LLM_TENSOR_CROSS_ATTN_ATTN_GATE, "blk.%d.cross_attn_attn_gate" },
+            { LLM_TENSOR_CROSS_ATTN_MLP_GATE,  "blk.%d.cross_attn_mlp_gate" },
+        },
+    },
     {
         LLM_ARCH_DECI,
         {
@@ -1681,6 +1717,14 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     // this tensor is loaded for T5, but never used
     {LLM_TENSOR_DEC_CROSS_ATTN_REL_B,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
     {LLM_TENSOR_BSKCN_TV,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_CROSS_ATTN_K_NORM,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_CROSS_ATTN_K_PROJ,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_CROSS_ATTN_O_PROJ,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_CROSS_ATTN_Q_NORM,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_CROSS_ATTN_Q_PROJ,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_CROSS_ATTN_V_PROJ,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_CROSS_ATTN_ATTN_GATE,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_CROSS_ATTN_MLP_GATE,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_CONV1D,                     {LLM_TENSOR_LAYER_INPUT,     GGML_OP_IM2COL}},
     {LLM_TENSOR_POS_NET_NORM,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_POS_NET_NORM1,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 439aaeab..6a989034 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -11,6 +11,7 @@
 enum llm_arch {
     LLM_ARCH_LLAMA,
     LLM_ARCH_LLAMA4,
+    LLM_ARCH_MLLAMA,
     LLM_ARCH_DECI,
     LLM_ARCH_FALCON,
     LLM_ARCH_BAICHUAN,
@@ -146,6 +147,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_SLIDING_WINDOW,
     LLM_KV_ATTENTION_SCALE,
     LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,
+    LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS,
     LLM_KV_ATTENTION_KEY_LENGTH_MLA,
     LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
 
@@ -347,6 +349,14 @@ enum llm_tensor {
     LLM_TENSOR_CLS,
     LLM_TENSOR_CLS_OUT,
     LLM_TENSOR_BSKCN_TV,
+    LLM_TENSOR_CROSS_ATTN_K_NORM,
+    LLM_TENSOR_CROSS_ATTN_K_PROJ,
+    LLM_TENSOR_CROSS_ATTN_O_PROJ,
+    LLM_TENSOR_CROSS_ATTN_Q_NORM,
+    LLM_TENSOR_CROSS_ATTN_Q_PROJ,
+    LLM_TENSOR_CROSS_ATTN_V_PROJ,
+    LLM_TENSOR_CROSS_ATTN_ATTN_GATE,
+    LLM_TENSOR_CROSS_ATTN_MLP_GATE,
     LLM_TENSOR_CONV1D,
     LLM_TENSOR_CONVNEXT_DW,
     LLM_TENSOR_CONVNEXT_NORM,
diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index 01d5ca57..8682b0e6 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -316,6 +316,7 @@ struct llama_batch llama_batch_get_one(
         /*n_tokens       =*/ n_tokens,
         /*tokens         =*/ tokens,
         /*embd           =*/ nullptr,
+        /*n_embd         =*/ 0,
         /*pos            =*/ nullptr,
         /*n_seq_id       =*/ nullptr,
         /*seq_id         =*/ nullptr,
@@ -328,6 +329,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
         /*n_tokens       =*/ 0,
         /*tokens         =*/ nullptr,
         /*embd           =*/ nullptr,
+        /*n_embd         =*/ 0,
         /*pos            =*/ nullptr,
         /*n_seq_id       =*/ nullptr,
         /*seq_id         =*/ nullptr,
@@ -336,6 +338,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
 
     if (embd) {
         batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
+        batch.n_embd = embd;
     } else {
         batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
     }
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 32f59819..0343ba8a 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -862,7 +862,7 @@ float * llama_context::get_logits_ith(int32_t i) {
             throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
         }
 
-        return logits + j*model.vocab.n_tokens();
+        return logits + j*model.hparams.n_vocab;
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
@@ -983,6 +983,10 @@ void llama_context::set_warmup(bool value) {
     cparams.warmup = value;
 }
 
+void llama_context::set_cross_attn(bool value) {
+    cparams.cross_attn = value;
+}
+
 void llama_context::set_adapter_lora(
             llama_adapter_lora * adapter,
             float scale) {
@@ -1058,7 +1062,7 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     const int64_t n_embd = hparams.n_embd;
 
-    sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
+    sbatch.from_batch(batch, batch.n_embd, /* simple_split */ true, /* logits_all */ true);
 
     const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
 
@@ -1198,10 +1202,9 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     const llama_batch & batch = batch_allocr.batch;
 
-    const auto & vocab   = model.vocab;
     const auto & hparams = model.hparams;
 
-    const int32_t n_vocab = vocab.n_tokens();
+    const int32_t n_vocab = hparams.n_vocab;
 
     const int64_t n_tokens_all = batch.n_tokens;
     const int64_t n_embd       = hparams.n_embd;
@@ -1249,7 +1252,7 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     const bool logits_all = n_outputs_all == n_tokens_all;
 
-    sbatch.from_batch(batch, n_embd,
+    sbatch.from_batch(batch, batch.n_embd,
             /* simple_split */ !kv_self->recurrent,
             /* logits_all   */ logits_all);
 
@@ -1483,12 +1486,11 @@ int llama_context::decode(llama_batch & inp_batch) {
 
 int32_t llama_context::output_reserve(int32_t n_outputs) {
     const auto & hparams = model.hparams;
-    const auto & vocab   = model.vocab;
 
     const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max());
 
     const auto n_batch = cparams.n_batch;
-    const auto n_vocab = vocab.n_tokens();
+    const auto n_vocab = hparams.n_vocab;
     const auto n_embd  = hparams.n_embd;
 
     // TODO: use a per-batch flag for logits presence instead
@@ -1558,7 +1560,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
 void llama_context::output_reorder() {
     auto & out_ids = sbatch.out_ids;
     if (!out_ids.empty()) {
-        const uint32_t n_vocab = model.vocab.n_tokens();
+        const uint32_t n_vocab = model.hparams.n_vocab;
         const uint32_t n_embd  = model.hparams.n_embd;
 
         GGML_ASSERT((size_t) n_outputs == out_ids.size());
@@ -2065,7 +2067,7 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
     {
         LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);
 
-        const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens());
+        const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.hparams.n_vocab);
 
         io.write(&logits_size, sizeof(logits_size));
 
@@ -2248,6 +2250,7 @@ llama_context_params llama_context_default_params() {
         /*.offload_kqv                 =*/ true,
         /*.flash_attn                  =*/ false,
         /*.no_perf                     =*/ true,
+        /*.cross_attn                  =*/ false,
         /*.abort_callback              =*/ nullptr,
         /*.abort_callback_data         =*/ nullptr,
     };
@@ -2375,6 +2378,10 @@ void llama_set_warmup(llama_context * ctx, bool warmup) {
     ctx->set_warmup(warmup);
 }
 
+void llama_set_cross_attention(struct llama_context * ctx, bool cross_attention) {
+    ctx->set_cross_attn(cross_attention);
+}
+
 void llama_synchronize(llama_context * ctx) {
     ctx->synchronize();
 }
diff --git a/src/llama-context.h b/src/llama-context.h
index 04facb54..baa03276 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -65,6 +65,7 @@ struct llama_context {
     void set_embeddings (bool value);
     void set_causal_attn(bool value);
     void set_warmup(bool value);
+    void set_cross_attn(bool value);
 
     void set_adapter_lora(
             llama_adapter_lora * adapter,
diff --git a/src/llama-cparams.h b/src/llama-cparams.h
index 30e550f0..85ad91b9 100644
--- a/src/llama-cparams.h
+++ b/src/llama-cparams.h
@@ -29,6 +29,7 @@ struct llama_cparams {
     bool offload_kqv;
     bool flash_attn;
     bool no_perf;
+    bool cross_attn;
     bool warmup;
 
     enum llama_pooling_type pooling_type;
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index a85e9728..d740c120 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -546,6 +546,12 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+void llm_graph_input_cross_attn_state::set_input(const llama_ubatch * ubatch) {
+    if (ubatch->embd) {
+        ggml_backend_tensor_set(cross_attn_state, ubatch->embd, 0, ggml_nbytes(cross_attn_state));
+    }
+}
+
 //
 // llm_graph_context
 //
@@ -1506,6 +1512,25 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
     return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
 }
 
+ggml_tensor * llm_graph_context::build_inp_cross_attn_state() const {
+    const int64_t n_embd = hparams.n_embd;
+
+    auto inp = std::make_unique<llm_graph_input_cross_attn_state>();
+
+    ggml_tensor * cur = nullptr;
+
+    inp->cross_attn_state = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd, 1601, 4);
+    ggml_set_input(inp->cross_attn_state);
+
+    cur = inp->cross_attn_state;
+
+    cb(cur, "inp_cross_attn_state", -1);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_cross * inp,
         ggml_cgraph * gf,
diff --git a/src/llama-graph.h b/src/llama-graph.h
index d192dc14..260a2af2 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -86,6 +86,7 @@ public:
 
     ggml_tensor * tokens = nullptr; // I32 [n_batch]
     ggml_tensor * embd   = nullptr; // F32 [n_embd, n_batch]
+    ggml_tensor * cross_attn_state; // F32 [n_embd, 1601, 4]
 };
 
 class llm_graph_input_pos : public llm_graph_input_i {
@@ -285,6 +286,16 @@ public:
     const llama_cross * cross = nullptr;
 };
 
+class llm_graph_input_cross_attn_state : public llm_graph_input_i {
+public:
+    llm_graph_input_cross_attn_state()          = default;
+    virtual ~llm_graph_input_cross_attn_state() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * cross_attn_state; // F32 [n_embd, 1601, 4]
+};
+
 //
 // llm_graph_result
 //
@@ -493,6 +504,7 @@ struct llm_graph_context {
     ggml_tensor * build_inp_cls() const;
     ggml_tensor * build_inp_s_copy() const;
     ggml_tensor * build_inp_s_mask() const;
+    ggml_tensor * build_inp_cross_attn_state() const;
 
     ggml_tensor * build_inp_cross_embd() const;
     ggml_tensor * build_inp_pos_bucket_enc() const;
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index 8a667960..6a02de03 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -85,3 +85,7 @@ bool llama_hparams::is_swa(uint32_t il) const {
 
     GGML_ABORT("fatal error");
 }
+
+bool llama_hparams::cross_attention_layers(uint32_t il) const {
+    return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end();
+}
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index 6e278945..c8a34d52 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -2,6 +2,8 @@
 
 #include "llama.h"
 
+#include <algorithm>
+
 #include <array>
 
 // bump if necessary
@@ -42,6 +44,7 @@ struct llama_hparams {
     uint32_t n_expert = 0;
     uint32_t n_expert_used = 0;
     uint32_t n_rel_attn_bkts = 0;
+    uint32_t n_vocab = 0;
 
     // note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
     uint32_t n_embd_head_k_mla = 0;
@@ -56,6 +59,7 @@ struct llama_hparams {
     std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
 
     std::array<std::array<uint32_t, LLAMA_MAX_LAYERS>, 4> n_bskcn_arr = {};
+    std::array<uint32_t, LLAMA_MAX_LAYERS> cross_attn_layers;
 
     uint32_t n_layer_dense_lead = 0;
     uint32_t n_lora_q           = 0;
@@ -158,6 +162,9 @@ struct llama_hparams {
     // Block skip connection
     bool n_bskcn(uint32_t n, uint32_t il) const;
 
+    // cross attention layers
+    bool cross_attention_layers(uint32_t il) const;
+
     bool is_swa(uint32_t il) const;
 };
 
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index 7c9d46d8..69f8d35a 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -95,8 +95,16 @@ bool llama_kv_cache_unified::init(
             return false;
         }
 
-        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
-        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
+        ggml_tensor * k, *v;
+
+        // for cross attention layers
+        if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layers(i)) {
+            k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i));
+            v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i));
+        } else {
+            k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
+            v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
+        }
         ggml_format_name(k, "cache_k_l%d", i);
         ggml_format_name(v, "cache_v_l%d", i);
         k_l.push_back(k);
diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
index a012aeae..2e11507d 100644
--- a/src/llama-model-loader.cpp
+++ b/src/llama-model-loader.cpp
@@ -315,6 +315,8 @@ namespace GGUFMeta {
         return true;
     }
 
+    template bool llama_model_loader::get_arr<std::array<unsigned int, 512>>(enum llm_kv kid, std::array<unsigned int, 512>& result, bool required);
+
     template<typename T, size_t N_MAX>
     bool llama_model_loader::get_arr(const std::string & key, std::array<T, N_MAX> & result, bool required) {
         const int kid = gguf_find_key(meta.get(), key.c_str());
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index aba42819..d051696c 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -419,6 +419,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
 
     // get general kv
     ml.get_key(LLM_KV_GENERAL_NAME, name, false);
+    ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab, false);
 
     // everything past this point is not vocab-related
     if (hparams.vocab_only) {
@@ -430,6 +431,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     ml.get_key(LLM_KV_BLOCK_COUNT,       hparams.n_layer);
     ml.get_key(LLM_KV_EXPERT_COUNT,      hparams.n_expert,      false);
     ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
+    ml.get_key(LLM_KV_VOCAB_SIZE,        hparams.n_vocab,       false);
 
     if (arch == LLM_ARCH_WAVTOKENIZER_DEC) {
         ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features);
@@ -453,9 +455,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     std::fill(hparams.n_head_arr.begin(),    hparams.n_head_arr.end(),    0);
     std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
     std::fill(hparams.n_ff_arr.begin(),      hparams.n_ff_arr.end(),      0);
+    std::fill(hparams.cross_attn_layers.begin(), hparams.cross_attn_layers.end(), -1);
 
     ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH,  hparams.n_ff_arr,   hparams.n_layer, false);
     ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false);
+    ml.get_arr(LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, hparams.cross_attn_layers, false);
 
     // n_head_kv is optional, default to n_head
     hparams.n_head_kv_arr = hparams.n_head_arr;
@@ -508,7 +512,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
 
         ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
 
-        if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON) {
+        if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_MLLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON) {
             if (hparams.n_rot != hparams.n_embd_head_k) {
                 throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
             }
@@ -571,6 +575,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     hparams.use_kq_norm = false;
                 }
             } break;
+        case LLM_ARCH_MLLAMA:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                switch (hparams.n_layer) {
+                    case 40: type = LLM_TYPE_11B; break;
+                    case 100: type = LLM_TYPE_90B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_DECI:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -1550,7 +1564,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
         const int64_t n_embd_head_v = hparams.n_embd_head_v;
         const int64_t n_ff          = hparams.n_ff();
         const int64_t n_embd_gqa    = n_embd_v_gqa;
-        const int64_t n_vocab       = vocab.n_tokens();
+        const int64_t n_vocab       = hparams.n_vocab;
         const int64_t n_token_types = vocab.n_token_types();
         const int64_t n_rot         = hparams.n_rot;
         const int64_t n_expert      = hparams.n_expert;
@@ -1803,6 +1817,52 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         }
                     }
                 } break;
+            case LLM_ARCH_MLLAMA:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab+8}, 0);
+
+                    // output
+                    {
+                        output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                        output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                        // if output is NULL, init from the input tok embed
+                        if (output == NULL) {
+                            output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
+                        }
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        if (hparams.cross_attention_layers(i)) {
+                            layer.cross_attn_k_norm = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_K_NORM,   "weight", i), {128}, 0);
+                            layer.cross_attn_k_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_K_PROJ,   "weight", i), {n_embd, 1024}, 0);
+                            layer.cross_attn_o_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_O_PROJ,   "weight", i), {n_embd, n_embd}, 0);
+                            layer.cross_attn_q_norm = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_Q_NORM, "weight", i), {128}, 0);
+                            layer.cross_attn_q_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_Q_PROJ, "weight", i), {n_embd, n_embd}, 0);
+                            layer.cross_attn_v_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_V_PROJ, "weight", i), {n_embd, 1024}, 0);
+                            layer.cross_attn_attn_gate = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_ATTN_GATE, i), {1}, 0);
+                            layer.cross_attn_mlp_gate = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_MLP_GATE, i), {1}, 0);
+                            layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                            layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        } else {
+                            layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                            layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+                            layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                            layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        }
+                    }
+                } break;
             case LLM_ARCH_DECI:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -4683,6 +4743,246 @@ struct llm_build_llama : public llm_graph_context {
     }
 };
 
+struct llm_build_mllama: public llm_graph_context {
+    llm_build_mllama(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        // mutable variable, needed during the last layer of the computation to skip unused tokens
+        int32_t n_tokens = this->n_tokens;
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+        ggml_tensor * inpCAS;
+
+        inpL = build_inp_embd(model.tok_embd);
+        inpCAS = build_inp_cross_attn_state();
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified();
+        const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            if (hparams.cross_attention_layers(il)) {
+                if (!ubatch.embd && !cparams.cross_attn) {
+                    continue;
+                }
+
+                // cross attention layer
+                ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_q_proj, cur);
+                cb(Qcur, "Qcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                cb(Qcur, "Qcur", il);
+
+                Qcur = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3));
+                cb(Qcur, "Qcur", il);
+
+                Qcur = build_norm(Qcur, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur, * Vcur;
+                if (ubatch.embd) {
+                    Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_k_proj, inpCAS);
+                    cb(Kcur, "Kcur", il);
+
+                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, 6404);
+                    cb(Kcur, "Kcur", il);
+
+                    Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
+                    cb(Kcur, "Kcur", il);
+
+                    Kcur = build_norm(Kcur, model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, il);
+                    cb(Kcur, "Kcur", il);
+
+                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self->k_l[il]));
+
+                    Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_v_proj, inpCAS);
+                    cb(Vcur, "Vcur", il);
+
+                    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, 6404);
+                    cb(Vcur, "Vcur", il);
+
+                    Vcur = ggml_permute(ctx0, Vcur, 0, 2, 1, 3);
+                    cb(Vcur, "Vcur", il);
+
+                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self->v_l[il]));
+                } else {
+                    Kcur = ggml_view_tensor(ctx0, kv_self->k_l[il]);
+                    cb(Kcur, "Kcur (view)", il);
+
+                    Vcur = ggml_view_tensor(ctx0, kv_self->v_l[il]);
+                    cb(Vcur, "Vcur (view)", il);
+                }
+
+                struct ggml_tensor * kq = ggml_mul_mat(ctx0, Kcur, Qcur);
+                cb(kq, "kq", il);
+
+                // TODO: apply causal masks
+                struct ggml_tensor * kq_soft_max = ggml_soft_max_ext(ctx0, kq, nullptr, 1.f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
+                cb(kq_soft_max, "kq_soft_max", il);
+
+                Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, Vcur));
+                cb(Vcur, "Vcur", il);
+
+                struct ggml_tensor * kqv = ggml_mul_mat(ctx0, Vcur, kq_soft_max);
+                cb(kqv, "kqv", il);
+
+                struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+                cb(kqv_merged, "kqv_merged", il);
+
+                cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
+                cb(cur, "kqv_merged_cont", il);
+
+                cur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_o_proj, cur);
+                cb(cur, "cur", il);
+
+                // TODO: do this in place once?
+                cur = ggml_mul(ctx0, cur, ggml_tanh(ctx0, model.layers[il].cross_attn_attn_gate));
+
+                struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+                cb(ffn_inp, "ffn_inp", il);
+
+                // feed-forward network
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+
+                // TODO: do this inplace once?
+                cur = ggml_add_inplace(ctx0, ggml_mul_inplace(ctx0, cur, ggml_tanh(ctx0, model.layers[il].cross_attn_mlp_gate)), ffn_inp);
+                cb(cur, "ffn_out", il);
+
+                cur = build_cvec(cur, il);
+                cb(cur, "l_out", il);
+
+                // input for next layer
+                inpL = cur;
+            } else {
+                // self attention layer
+
+                // rope freq factors for llama3; may return nullptr for llama2 and other models
+                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                    model.layers[il].wo, model.layers[il].bo,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+
+                if (il == n_layer - 1) {
+                    // skip computing output for unused tokens
+                    struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                    n_tokens = n_outputs;
+                    cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                    inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+                }
+
+                struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+                cb(ffn_inp, "ffn_inp", il);
+
+                // feed-forward network
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+
+                cur = ggml_add(ctx0, cur, ffn_inp);
+                cb(cur, "ffn_out", il);
+
+                cur = build_cvec(cur, il);
+                cb(cur, "l_out", il);
+
+                // input for next layer
+                inpL = cur;
+            }
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 struct llm_build_deci : public llm_graph_context {
     llm_build_deci(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -13017,6 +13317,10 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_llama>(*this, params, gf);
             } break;
+        case LLM_ARCH_MLLAMA:
+            {
+                llm = std::make_unique<llm_build_mllama>(*this, params, gf);
+            } break;
         case LLM_ARCH_DECI:
             {
                 llm = std::make_unique<llm_build_deci>(*this, params, gf);
@@ -13377,6 +13681,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         // use what we call a normal RoPE, operating on pairs of consecutive head values
         case LLM_ARCH_LLAMA:
         case LLM_ARCH_LLAMA4:
+        case LLM_ARCH_MLLAMA:
         case LLM_ARCH_DECI:
         case LLM_ARCH_BAICHUAN:
         case LLM_ARCH_STARCODER:
diff --git a/src/llama-model.h b/src/llama-model.h
index 5865d5e9..72bab5be 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -11,6 +11,7 @@
 #include <string>
 #include <unordered_map>
 #include <vector>
+#include <stdexcept>
 
 struct llama_cparams;
 struct llama_ubatch;
@@ -70,6 +71,7 @@ enum llm_type {
     LLM_TYPE_40B,
     LLM_TYPE_65B,
     LLM_TYPE_70B,
+    LLM_TYPE_90B,
     LLM_TYPE_236B,
     LLM_TYPE_314B,
     LLM_TYPE_671B,
@@ -310,6 +312,16 @@ struct llama_layer {
 
     struct ggml_tensor * bskcn_tv = nullptr;
 
+    // cross attention
+    struct ggml_tensor * cross_attn_k_norm = nullptr;
+    struct ggml_tensor * cross_attn_k_proj = nullptr;
+    struct ggml_tensor * cross_attn_o_proj = nullptr;
+    struct ggml_tensor * cross_attn_q_norm = nullptr;
+    struct ggml_tensor * cross_attn_q_proj = nullptr;
+    struct ggml_tensor * cross_attn_v_proj = nullptr;
+    struct ggml_tensor * cross_attn_attn_gate = nullptr;
+    struct ggml_tensor * cross_attn_mlp_gate = nullptr;
+
     struct llama_layer_posnet posnet;
 
     struct llama_layer_convnext convnext;
diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
index 7dc54227..223e1f3f 100644
--- a/src/llama-quant.cpp
+++ b/src/llama-quant.cpp
@@ -639,7 +639,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         if (llama_model_has_encoder(&model)) {
             n_attn_layer *= 3;
         }
-        GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
+        if (qs.n_attention_wv != n_attn_layer) {
+            LLAMA_LOG_WARN("%s: n_attention_wv is unexpected, expected: %d, found: %d\n", __func__, n_attn_layer, qs.n_attention_wv);
+        }
     }
 
     size_t total_size_org = 0;