From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: jmorganca <jmorganca@gmail.com>
Date: Sun, 20 Apr 2025 16:12:36 -0700
Subject: [PATCH] add mllama support

adds support for the llama 3.2 vision architecture
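
A rough usage sketch of the new public surface (illustrative only, not part
of the diff below; `ctx`, `n_tokens`, `n_embd` and the image embedding data
are assumed to come from the usual llama.cpp setup and from a separate
vision encoder):

    // enable cross attention, either up front on the context params ...
    struct llama_context_params cparams = llama_context_default_params();
    cparams.cross_attn = true;

    // ... or toggled per request with the new setter
    llama_set_cross_attention(ctx, true);

    // llama_batch now carries an explicit n_embd; llama_batch_init() fills
    // it in whenever an embedding batch is allocated
    struct llama_batch batch = llama_batch_init(/*n_tokens_alloc=*/ n_tokens,
                                                /*embd=*/ n_embd,
                                                /*n_seq_max=*/ 1);
    // batch.embd, batch.pos and batch.seq_id are filled by the caller as
    // before; for LLM_ARCH_MLLAMA models, llama_decode() then copies the
    // vision embeddings into the per-layer cross-attention K/V tensors
    llama_batch_free(batch);

The fixed sizes in the graph code appear to come from the vision encoder
output shape: 1601 tokens per tile x 4 tiles = 6404 rows in the
cross-attention K/V tensors.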
---
 examples/llava/llava.cpp      |   5 +-
 examples/llava/mtmd.cpp       |   6 +-
 ggml/src/ggml-backend-reg.cpp |   6 +-
 include/llama.h               |   6 +
 src/llama-arch.cpp            |  44 +++++
 src/llama-arch.h              |  10 ++
 src/llama-batch.cpp           |   3 +
 src/llama-context.cpp         |  25 ++-
 src/llama-context.h           |   1 +
 src/llama-cparams.h           |   1 +
 src/llama-graph.cpp           |  25 +++
 src/llama-graph.h             |  12 ++
 src/llama-hparams.cpp         |   4 +
 src/llama-hparams.h           |   7 +
 src/llama-kv-cache.cpp        |  12 +-
 src/llama-model-loader.cpp    |   2 +
 src/llama-model.cpp           | 309 +++++++++++++++++++++++++++++++++-
 src/llama-model.h             |  12 ++
 src/llama-quant.cpp           |   4 +-
 19 files changed, 473 insertions(+), 21 deletions(-)

diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp
index c00d16ae..bab027b5 100644
--- a/examples/llava/llava.cpp
+++ b/examples/llava/llava.cpp
@@ -457,7 +457,7 @@ struct llava_embd_batch {
     std::vector<llama_seq_id *> seq_ids;
     std::vector<int8_t>         logits;
     llama_batch batch;
-    llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
+    llava_embd_batch(float * embd, int32_t n_embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
         pos     .resize(n_tokens);
         n_seq_id.resize(n_tokens);
         seq_ids .resize(n_tokens + 1);
@@ -469,6 +469,7 @@ struct llava_embd_batch {
             /*n_tokens       =*/ n_tokens,
             /*tokens         =*/ nullptr,
             /*embd           =*/ embd,
+            /*n_embd         =*/ n_embd,
             /*pos            =*/ pos.data(),
             /*n_seq_id       =*/ n_seq_id.data(),
             /*seq_id         =*/ seq_ids.data(),
@@ -492,7 +493,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
             n_eval = n_batch;
         }
         float * embd = image_embed->embed+i*n_embd;
-        llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0);
+        llava_embd_batch llava_batch = llava_embd_batch(embd, n_embd, n_eval, *n_past, 0);
         if (llama_decode(ctx_llama, llava_batch.batch)) {
             LOG_ERR("%s : failed to eval\n", __func__);
             return false;
diff --git a/examples/llava/mtmd.cpp b/examples/llava/mtmd.cpp
index 7081fd73..c14ac501 100644
--- a/examples/llava/mtmd.cpp
+++ b/examples/llava/mtmd.cpp
@@ -476,7 +476,7 @@ struct decode_embd_batch {
     std::vector<llama_seq_id *> seq_ids;
     std::vector<int8_t>         logits;
     llama_batch batch;
-    decode_embd_batch(float * embd, int32_t n_tokens, int n_pos_per_embd, int n_mmproj_embd) : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) {
+    decode_embd_batch(float * embd, int32_t n_embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) {
         pos     .resize(n_tokens * n_pos_per_embd);
         n_seq_id.resize(n_tokens);
         seq_ids .resize(n_tokens + 1);
@@ -487,6 +487,7 @@ struct decode_embd_batch {
             /*n_tokens       =*/ n_tokens,
             /*tokens         =*/ nullptr,
             /*embd           =*/ embd,
+            /*n_embd         =*/ n_embd,
             /*pos            =*/ pos.data(),
             /*n_seq_id       =*/ n_seq_id.data(),
             /*seq_id         =*/ seq_ids.data(),
@@ -610,7 +611,8 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
             int32_t i_batch = 0;
             int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
             float * embd = mtmd_get_output_embd(ctx);
-            decode_embd_batch batch_embd(embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
+            int n_embd  = llama_model_n_embd(llama_get_model(lctx));
+            decode_embd_batch batch_embd(embd, n_embd, n_tokens, n_past, 0);
 
             const int nx = mtmd_image_tokens_get_nx(chunk.tokens_image.get());
             const int ny = mtmd_image_tokens_get_ny(chunk.tokens_image.get());
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
index 405d8e31..82ae1b5b 100644
--- a/ggml/src/ggml-backend-reg.cpp
+++ b/ggml/src/ggml-backend-reg.cpp
@@ -178,9 +178,9 @@ struct ggml_backend_registry {
 #ifdef GGML_USE_CANN
         register_backend(ggml_backend_cann_reg());
 #endif
-#ifdef GGML_USE_BLAS
-        register_backend(ggml_backend_blas_reg());
-#endif
+// #ifdef GGML_USE_BLAS
+//         register_backend(ggml_backend_blas_reg());
+// #endif
 #ifdef GGML_USE_RPC
         register_backend(ggml_backend_rpc_reg());
 #endif
diff --git a/include/llama.h b/include/llama.h
index 06c56395..f1628e88 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -256,6 +256,7 @@ extern "C" {
 
         llama_token  *  token;
         float        *  embd;
+        int32_t         n_embd;
         llama_pos    *  pos;
         int32_t      *  n_seq_id;
         llama_seq_id ** seq_id;
@@ -358,6 +359,7 @@ extern "C" {
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
         bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
         bool no_perf;     // whether to measure performance timings
+        bool cross_attn;  // whether to use cross attention
 
         // Abort callback
         // if it returns true, execution of llama_decode() will be aborted
@@ -459,6 +461,10 @@ extern "C" {
             struct llama_context_params   params),
             "use llama_init_from_model instead");
 
+    // TODO (jmorganca): this should most likely be passed in as part of a batch
+    // and not set on the context for all batches.
+    LLAMA_API void llama_set_cross_attention(struct llama_context * ctx, bool cross_attn_state);
+
     // Frees all allocated memory
     LLAMA_API void llama_free(struct llama_context * ctx);
 
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 5ab3f572..eb7b5325 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -6,6 +6,7 @@
 
 static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_LLAMA,            "llama"            },
+    { LLM_ARCH_MLLAMA,           "mllama"           },
     { LLM_ARCH_LLAMA4,           "llama4"           },
     { LLM_ARCH_DECI,             "deci"             },
     { LLM_ARCH_FALCON,           "falcon"           },
@@ -144,6 +145,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_SLIDING_WINDOW,               "%s.attention.sliding_window"               },
     { LLM_KV_ATTENTION_SCALE,                        "%s.attention.scale"                        },
     { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,        "%s.attention.block_skip_connection"        },
+    { LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS,       "%s.attention.cross_attention_layers"       },
     { LLM_KV_ATTENTION_KEY_LENGTH_MLA,               "%s.attention.key_length_mla"               },
     { LLM_KV_ATTENTION_VALUE_LENGTH_MLA,             "%s.attention.value_length_mla"             },
 
@@ -273,6 +275,40 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_SHEXP,    "blk.%d.ffn_up_shexp" },
         },
     },
+    {
+        LLM_ARCH_MLLAMA,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXP,    "blk.%d.ffn_gate.%d" },
+            { LLM_TENSOR_FFN_DOWN_EXP,    "blk.%d.ffn_down.%d" },
+            { LLM_TENSOR_FFN_UP_EXP,      "blk.%d.ffn_up.%d" },
+            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_CROSS_ATTN_K_NORM,    "blk.%d.cross_attn_k_norm" },
+            { LLM_TENSOR_CROSS_ATTN_K_PROJ,    "blk.%d.cross_attn_k_proj" },
+            { LLM_TENSOR_CROSS_ATTN_O_PROJ,    "blk.%d.cross_attn_o_proj" },
+            { LLM_TENSOR_CROSS_ATTN_Q_NORM,    "blk.%d.cross_attn_q_norm" },
+            { LLM_TENSOR_CROSS_ATTN_Q_PROJ,    "blk.%d.cross_attn_q_proj" },
+            { LLM_TENSOR_CROSS_ATTN_V_PROJ,    "blk.%d.cross_attn_v_proj" },
+            { LLM_TENSOR_CROSS_ATTN_ATTN_GATE, "blk.%d.cross_attn_attn_gate" },
+            { LLM_TENSOR_CROSS_ATTN_MLP_GATE,  "blk.%d.cross_attn_mlp_gate" },
+        },
+    },
     {
         LLM_ARCH_DECI,
         {
@@ -1701,6 +1737,14 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     // this tensor is loaded for T5, but never used
     {LLM_TENSOR_DEC_CROSS_ATTN_REL_B,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
     {LLM_TENSOR_BSKCN_TV,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_CROSS_ATTN_K_NORM,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_CROSS_ATTN_K_PROJ,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_CROSS_ATTN_O_PROJ,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_CROSS_ATTN_Q_NORM,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_CROSS_ATTN_Q_PROJ,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_CROSS_ATTN_V_PROJ,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_CROSS_ATTN_ATTN_GATE,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_CROSS_ATTN_MLP_GATE,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_CONV1D,                     {LLM_TENSOR_LAYER_INPUT,     GGML_OP_IM2COL}},
     {LLM_TENSOR_POS_NET_NORM,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_POS_NET_NORM1,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 525c1b7d..bc8a4f0b 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -11,6 +11,7 @@
 enum llm_arch {
     LLM_ARCH_LLAMA,
     LLM_ARCH_LLAMA4,
+    LLM_ARCH_MLLAMA,
     LLM_ARCH_DECI,
     LLM_ARCH_FALCON,
     LLM_ARCH_BAICHUAN,
@@ -148,6 +149,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_SLIDING_WINDOW,
     LLM_KV_ATTENTION_SCALE,
     LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,
+    LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS,
     LLM_KV_ATTENTION_KEY_LENGTH_MLA,
     LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
 
@@ -349,6 +351,14 @@ enum llm_tensor {
     LLM_TENSOR_CLS,
     LLM_TENSOR_CLS_OUT,
     LLM_TENSOR_BSKCN_TV,
+    LLM_TENSOR_CROSS_ATTN_K_NORM,
+    LLM_TENSOR_CROSS_ATTN_K_PROJ,
+    LLM_TENSOR_CROSS_ATTN_O_PROJ,
+    LLM_TENSOR_CROSS_ATTN_Q_NORM,
+    LLM_TENSOR_CROSS_ATTN_Q_PROJ,
+    LLM_TENSOR_CROSS_ATTN_V_PROJ,
+    LLM_TENSOR_CROSS_ATTN_ATTN_GATE,
+    LLM_TENSOR_CROSS_ATTN_MLP_GATE,
     LLM_TENSOR_CONV1D,
     LLM_TENSOR_CONVNEXT_DW,
     LLM_TENSOR_CONVNEXT_NORM,
diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index 01d5ca57..8682b0e6 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -316,6 +316,7 @@ struct llama_batch llama_batch_get_one(
         /*n_tokens       =*/ n_tokens,
         /*tokens         =*/ tokens,
         /*embd           =*/ nullptr,
+        /*n_embd         =*/ 0,
         /*pos            =*/ nullptr,
         /*n_seq_id       =*/ nullptr,
         /*seq_id         =*/ nullptr,
@@ -328,6 +329,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
         /*n_tokens       =*/ 0,
         /*tokens         =*/ nullptr,
         /*embd           =*/ nullptr,
+        /*n_embd         =*/ 0,
         /*pos            =*/ nullptr,
         /*n_seq_id       =*/ nullptr,
         /*seq_id         =*/ nullptr,
@@ -336,6 +338,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
 
     if (embd) {
         batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
+        batch.n_embd = embd;
     } else {
         batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
     }
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 9c1fe93f..cd06ad91 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -851,7 +851,7 @@ float * llama_context::get_logits_ith(int32_t i) {
             throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
         }
 
-        return logits + j*model.vocab.n_tokens();
+        return logits + j*model.hparams.n_vocab;
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
@@ -972,6 +972,10 @@ void llama_context::set_warmup(bool value) {
     cparams.warmup = value;
 }
 
+void llama_context::set_cross_attn(bool value) {
+    cparams.cross_attn = value;
+}
+
 void llama_context::set_adapter_lora(
             llama_adapter_lora * adapter,
             float scale) {
@@ -1047,7 +1051,7 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     const int64_t n_embd = hparams.n_embd;
 
-    sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
+    sbatch.from_batch(batch, batch.n_embd, /* simple_split */ true, /* logits_all */ true);
 
     const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
 
@@ -1187,10 +1191,9 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     const llama_batch & batch = batch_allocr.batch;
 
-    const auto & vocab   = model.vocab;
     const auto & hparams = model.hparams;
 
-    const int32_t n_vocab = vocab.n_tokens();
+    const int32_t n_vocab = hparams.n_vocab;
 
     const int64_t n_tokens_all = batch.n_tokens;
     const int64_t n_embd       = hparams.n_embd;
@@ -1238,7 +1241,7 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     const bool logits_all = n_outputs_all == n_tokens_all;
 
-    sbatch.from_batch(batch, n_embd,
+    sbatch.from_batch(batch, batch.n_embd,
             /* simple_split */ !kv_self->recurrent,
             /* logits_all   */ logits_all);
 
@@ -1472,12 +1475,11 @@ int llama_context::decode(llama_batch & inp_batch) {
 
 int32_t llama_context::output_reserve(int32_t n_outputs) {
     const auto & hparams = model.hparams;
-    const auto & vocab   = model.vocab;
 
     const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max());
 
     const auto n_batch = cparams.n_batch;
-    const auto n_vocab = vocab.n_tokens();
+    const auto n_vocab = hparams.n_vocab;
     const auto n_embd  = hparams.n_embd;
 
     // TODO: use a per-batch flag for logits presence instead
@@ -1545,7 +1547,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
 void llama_context::output_reorder() {
     auto & out_ids = sbatch.out_ids;
     if (!out_ids.empty()) {
-        const uint32_t n_vocab = model.vocab.n_tokens();
+        const uint32_t n_vocab = model.hparams.n_vocab;
         const uint32_t n_embd  = model.hparams.n_embd;
 
         GGML_ASSERT((size_t) n_outputs == out_ids.size());
@@ -2052,7 +2054,7 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
     {
         LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);
 
-        const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens());
+        const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.hparams.n_vocab);
 
         io.write(&logits_size, sizeof(logits_size));
 
@@ -2235,6 +2237,7 @@ llama_context_params llama_context_default_params() {
         /*.offload_kqv                 =*/ true,
         /*.flash_attn                  =*/ false,
         /*.no_perf                     =*/ true,
+        /*.cross_attn                  =*/ false,
         /*.abort_callback              =*/ nullptr,
         /*.abort_callback_data         =*/ nullptr,
     };
@@ -2362,6 +2365,10 @@ void llama_set_warmup(llama_context * ctx, bool warmup) {
     ctx->set_warmup(warmup);
 }
 
+void llama_set_cross_attention(struct llama_context * ctx, bool cross_attention) {
+    ctx->set_cross_attn(cross_attention);
+}
+
 void llama_synchronize(llama_context * ctx) {
     ctx->synchronize();
 }
diff --git a/src/llama-context.h b/src/llama-context.h
index 5457f077..a50c4afa 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -65,6 +65,7 @@ struct llama_context {
     void set_embeddings (bool value);
     void set_causal_attn(bool value);
     void set_warmup(bool value);
+    void set_cross_attn(bool value);
 
     void set_adapter_lora(
             llama_adapter_lora * adapter,
diff --git a/src/llama-cparams.h b/src/llama-cparams.h
index 30e550f0..85ad91b9 100644
--- a/src/llama-cparams.h
+++ b/src/llama-cparams.h
@@ -29,6 +29,7 @@ struct llama_cparams {
     bool offload_kqv;
     bool flash_attn;
     bool no_perf;
+    bool cross_attn;
     bool warmup;
 
     enum llama_pooling_type pooling_type;
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index fabb9ca2..b67216a4 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -560,6 +560,12 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+void llm_graph_input_cross_attn_state::set_input(const llama_ubatch * ubatch) {
+    if (ubatch->embd) {
+        ggml_backend_tensor_set(cross_attn_state, ubatch->embd, 0, ggml_nbytes(cross_attn_state));
+    }
+}
+
 //
 // llm_graph_context
 //
@@ -1532,6 +1538,25 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
     return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
 }
 
+ggml_tensor * llm_graph_context::build_inp_cross_attn_state() const {
+    const int64_t n_embd = hparams.n_embd;
+
+    auto inp = std::make_unique<llm_graph_input_cross_attn_state>();
+
+    ggml_tensor * cur = nullptr;
+
+    inp->cross_attn_state = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd, 1601, 4);
+    ggml_set_input(inp->cross_attn_state);
+
+    cur = inp->cross_attn_state;
+
+    cb(cur, "inp_cross_attn_state", -1);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_cross * inp,
         ggml_cgraph * gf,
diff --git a/src/llama-graph.h b/src/llama-graph.h
index d0c8d321..0fe18150 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -86,6 +86,7 @@ public:
 
     ggml_tensor * tokens = nullptr; // I32 [n_batch]
     ggml_tensor * embd   = nullptr; // F32 [n_embd, n_batch]
+    ggml_tensor * cross_attn_state; // F32 [n_embd, 1601, 4]
 };
 
 class llm_graph_input_pos : public llm_graph_input_i {
@@ -283,6 +284,16 @@ public:
     const llama_cross * cross = nullptr;
 };
 
+class llm_graph_input_cross_attn_state : public llm_graph_input_i {
+public:
+    llm_graph_input_cross_attn_state()          = default;
+    virtual ~llm_graph_input_cross_attn_state() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * cross_attn_state; // F32 [n_embd, 1601, 4]
+};
+
 //
 // llm_graph_result
 //
@@ -491,6 +502,7 @@ struct llm_graph_context {
     ggml_tensor * build_inp_cls() const;
     ggml_tensor * build_inp_s_copy() const;
     ggml_tensor * build_inp_s_mask() const;
+    ggml_tensor * build_inp_cross_attn_state() const;
 
     ggml_tensor * build_inp_cross_embd() const;
     ggml_tensor * build_inp_pos_bucket_enc() const;
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index 8a667960..6a02de03 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -85,3 +85,7 @@ bool llama_hparams::is_swa(uint32_t il) const {
 
     GGML_ABORT("fatal error");
 }
+
+bool llama_hparams::cross_attention_layers(uint32_t il) const {
+    return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end();
+}
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index 48dce407..b6fc7e6d 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -2,6 +2,8 @@
 
 #include "llama.h"
 
+#include <algorithm>
+
 #include <array>
 
 // bump if necessary
@@ -42,6 +44,7 @@ struct llama_hparams {
     uint32_t n_expert = 0;
     uint32_t n_expert_used = 0;
     uint32_t n_rel_attn_bkts = 0;
+    uint32_t n_vocab = 0;
 
     // note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
     uint32_t n_embd_head_k_mla = 0;
@@ -56,6 +59,7 @@ struct llama_hparams {
     std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
 
     std::array<std::array<uint32_t, LLAMA_MAX_LAYERS>, 4> n_bskcn_arr = {};
+    std::array<uint32_t, LLAMA_MAX_LAYERS> cross_attn_layers;
 
     uint32_t n_layer_dense_lead = 0;
     uint32_t n_lora_q           = 0;
@@ -159,6 +163,9 @@ struct llama_hparams {
     // Block skip connection
     bool n_bskcn(uint32_t n, uint32_t il) const;
 
+    // cross attention layers
+    bool cross_attention_layers(uint32_t il) const;
+
     bool is_swa(uint32_t il) const;
 };
 
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index 7c9d46d8..69f8d35a 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -95,8 +95,16 @@ bool llama_kv_cache_unified::init(
             return false;
         }
 
-        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
-        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
+        ggml_tensor * k, *v;
+
+        // for cross attention layers
+        if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layers(i)) {
+            k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i));
+            v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i));
+        } else {
+            k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
+            v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
+        }
         ggml_format_name(k, "cache_k_l%d", i);
         ggml_format_name(v, "cache_v_l%d", i);
         k_l.push_back(k);
diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
index a012aeae..2e11507d 100644
--- a/src/llama-model-loader.cpp
+++ b/src/llama-model-loader.cpp
@@ -315,6 +315,8 @@ namespace GGUFMeta {
         return true;
     }
 
+    template bool llama_model_loader::get_arr<std::array<unsigned int, 512>>(enum llm_kv kid, std::array<unsigned int, 512>& result, bool required);
+
     template<typename T, size_t N_MAX>
     bool llama_model_loader::get_arr(const std::string & key, std::array<T, N_MAX> & result, bool required) {
         const int kid = gguf_find_key(meta.get(), key.c_str());
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 572378c9..9d099f11 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -423,6 +423,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
 
     // get general kv
     ml.get_key(LLM_KV_GENERAL_NAME, name, false);
+    ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab, false);
 
     // everything past this point is not vocab-related
     if (hparams.vocab_only) {
@@ -434,6 +435,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     ml.get_key(LLM_KV_BLOCK_COUNT,       hparams.n_layer);
     ml.get_key(LLM_KV_EXPERT_COUNT,      hparams.n_expert,      false);
     ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
+    ml.get_key(LLM_KV_VOCAB_SIZE,        hparams.n_vocab,       false);
 
     if (arch == LLM_ARCH_WAVTOKENIZER_DEC) {
         ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features);
@@ -457,9 +459,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     std::fill(hparams.n_head_arr.begin(),    hparams.n_head_arr.end(),    0);
     std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
     std::fill(hparams.n_ff_arr.begin(),      hparams.n_ff_arr.end(),      0);
+    std::fill(hparams.cross_attn_layers.begin(), hparams.cross_attn_layers.end(), -1);
 
     ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH,  hparams.n_ff_arr,   hparams.n_layer, false);
     ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false);
+    ml.get_arr(LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, hparams.cross_attn_layers, false);
 
     // n_head_kv is optional, default to n_head
     hparams.n_head_kv_arr = hparams.n_head_arr;
@@ -512,7 +516,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
 
         ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
 
-        if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON) {
+        if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_MLLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON) {
             if (hparams.n_rot != hparams.n_embd_head_k) {
                 throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
             }
@@ -575,6 +579,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     hparams.use_kq_norm = false;
                 }
             } break;
+        case LLM_ARCH_MLLAMA:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                switch (hparams.n_layer) {
+                    case 40: type = LLM_TYPE_11B; break;
+                    case 100: type = LLM_TYPE_90B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_DECI:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -1562,7 +1576,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
         const int64_t n_embd_head_v = hparams.n_embd_head_v;
         const int64_t n_ff          = hparams.n_ff();
         const int64_t n_embd_gqa    = n_embd_v_gqa;
-        const int64_t n_vocab       = vocab.n_tokens();
+        const int64_t n_vocab       = hparams.n_vocab;
         const int64_t n_token_types = vocab.n_token_types();
         const int64_t n_rot         = hparams.n_rot;
         const int64_t n_expert      = hparams.n_expert;
@@ -1815,6 +1829,52 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         }
                     }
                 } break;
+            case LLM_ARCH_MLLAMA:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab+8}, 0);
+
+                    // output
+                    {
+                        output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                        output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                        // if output is NULL, init from the input tok embed
+                        if (output == NULL) {
+                            output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
+                        }
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        if (hparams.cross_attention_layers(i)) {
+                            layer.cross_attn_k_norm = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_K_NORM,   "weight", i), {128}, 0);
+                            layer.cross_attn_k_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_K_PROJ,   "weight", i), {n_embd, 1024}, 0);
+                            layer.cross_attn_o_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_O_PROJ,   "weight", i), {n_embd, n_embd}, 0);
+                            layer.cross_attn_q_norm = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_Q_NORM, "weight", i), {128}, 0);
+                            layer.cross_attn_q_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_Q_PROJ, "weight", i), {n_embd, n_embd}, 0);
+                            layer.cross_attn_v_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_V_PROJ, "weight", i), {n_embd, 1024}, 0);
+                            layer.cross_attn_attn_gate = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_ATTN_GATE, i), {1}, 0);
+                            layer.cross_attn_mlp_gate = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_MLP_GATE, i), {1}, 0);
+                            layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                            layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        } else {
+                            layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                            layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+                            layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                            layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        }
+                    }
+                } break;
             case LLM_ARCH_DECI:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -4707,6 +4767,246 @@ struct llm_build_llama : public llm_graph_context {
     }
 };
 
+struct llm_build_mllama: public llm_graph_context {
+    llm_build_mllama(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        // mutable variable, needed during the last layer of the computation to skip unused tokens
+        int32_t n_tokens = this->n_tokens;
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+        ggml_tensor * inpCAS;
+
+        inpL = build_inp_embd(model.tok_embd);
+        inpCAS = build_inp_cross_attn_state();
+
+          // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified();
+        const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            if (hparams.cross_attention_layers(il)) {
+                if (!ubatch.embd && !cparams.cross_attn) {
+                    continue;
+                }
+
+                // cross attention layer
+                ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_q_proj, cur);
+                cb(Qcur, "Qcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                cb(Qcur, "Qcur", il);
+
+                Qcur = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3));
+                cb(Qcur, "Qcur", il);
+
+                Qcur = build_norm(Qcur, model.layers[il].cross_attn_q_norm, NULL, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur, * Vcur;
+                if (ubatch.embd) {
+                    Kcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_k_proj, inpCAS);
+                    cb(Kcur, "Kcur", il);
+
+                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, 6404);
+                    cb(Kcur, "Kcur", il);
+
+                    Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
+                    cb(Kcur, "Kcur", il);
+
+                    Kcur = build_norm(Kcur, model.layers[il].cross_attn_k_norm, NULL, LLM_NORM_RMS, il);
+                    cb(Kcur, "Kcur", il);
+
+                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, kv_self->k_l[il]));
+
+                    Vcur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_v_proj, inpCAS);
+                    cb(Vcur, "Vcur", il);
+
+                    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, 6404);
+                    cb(Vcur, "Vcur", il);
+
+                    Vcur = ggml_permute(ctx0, Vcur, 0, 2, 1, 3);
+                    cb(Vcur, "Vcur", il);
+
+                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, kv_self->v_l[il]));
+                } else {
+                    Kcur = ggml_view_tensor(ctx0, kv_self->k_l[il]);
+                    cb(Kcur, "Kcur (view)", il);
+
776
+                    Vcur = ggml_view_tensor(ctx0, kv_self->v_l[il]);
+                    cb(Vcur, "Vcur (view)", il);
+                }
+
+                struct ggml_tensor * kq = ggml_mul_mat(ctx0, Kcur, Qcur);
+                cb(kq, "kq", il);
+
+                // TODO: apply causal masks
+                struct ggml_tensor * kq_soft_max = ggml_soft_max_ext(ctx0, kq, nullptr, 1.f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
+                cb(kq_soft_max, "kq_soft_max", il);
+
+                Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, Vcur));
+                cb(Vcur, "Vcur", il);
+
+                struct ggml_tensor * kqv = ggml_mul_mat(ctx0, Vcur, kq_soft_max);
+                cb(kqv, "kqv", il);
+
+                struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+                cb(kqv_merged, "kqv_merged", il);
+
+                cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
+                cb(cur, "kqv_merged_cont", il);
+
+                cur = ggml_mul_mat(ctx0, model.layers[il].cross_attn_o_proj, cur);
+                cb(cur, "cur", il);
+
+                // TODO: do this in place once?
+                cur = ggml_mul(ctx0, cur, ggml_tanh(ctx0, model.layers[il].cross_attn_attn_gate));
+
+                struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+                cb(ffn_inp, "ffn_inp", il);
+
+                // feed-forward network
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+
+                // TODO: do this inplace once?
+                cur = ggml_add_inplace(ctx0, ggml_mul_inplace(ctx0, cur, ggml_tanh(ctx0, model.layers[il].cross_attn_mlp_gate)), ffn_inp);
+                cb(cur, "ffn_out", il);
+
+                cur = build_cvec(cur, il);
+                cb(cur, "l_out", il);
+
+                // input for next layer
+                inpL = cur;
+            } else {
+                // self attention layer
+
+                // rope freq factors for llama3; may return nullptr for llama2 and other models
+                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                    model.layers[il].wo, model.layers[il].bo,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+
+                if (il == n_layer - 1) {
+                    // skip computing output for unused tokens
+                    struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                    n_tokens = n_outputs;
+                    cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                    inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+                }
+
+                struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+                cb(ffn_inp, "ffn_inp", il);
+
+                // feed-forward network
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+
+                cur = ggml_add(ctx0, cur, ffn_inp);
+                cb(cur, "ffn_out", il);
+
+                cur = build_cvec(cur, il);
+                cb(cur, "l_out", il);
+
+                // input for next layer
+                inpL = cur;
+            }
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 struct llm_build_deci : public llm_graph_context {
     llm_build_deci(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -13063,6 +13363,10 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_llama>(*this, params, gf);
             } break;
+        case LLM_ARCH_MLLAMA:
+            {
+                llm = std::make_unique<llm_build_mllama>(*this, params, gf);
+            } break;
         case LLM_ARCH_DECI:
             {
                 llm = std::make_unique<llm_build_deci>(*this, params, gf);
@@ -13424,6 +13728,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         // use what we call a normal RoPE, operating on pairs of consecutive head values
         case LLM_ARCH_LLAMA:
         case LLM_ARCH_LLAMA4:
+        case LLM_ARCH_MLLAMA:
         case LLM_ARCH_DECI:
         case LLM_ARCH_BAICHUAN:
         case LLM_ARCH_STARCODER:
diff --git a/src/llama-model.h b/src/llama-model.h
index 856e6042..6be91282 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -11,6 +11,7 @@
 #include <string>
 #include <unordered_map>
 #include <vector>
+#include <stdexcept>
 
 struct llama_cparams;
 struct llama_ubatch;
@@ -73,6 +74,7 @@ enum llm_type {
     LLM_TYPE_40B,
     LLM_TYPE_65B,
     LLM_TYPE_70B,
+    LLM_TYPE_90B,
     LLM_TYPE_236B,
     LLM_TYPE_290B,
     LLM_TYPE_314B,
@@ -314,6 +316,16 @@ struct llama_layer {
 
     struct ggml_tensor * bskcn_tv = nullptr;
 
+    // cross attention
+    struct ggml_tensor * cross_attn_k_norm = nullptr;
+    struct ggml_tensor * cross_attn_k_proj = nullptr;
+    struct ggml_tensor * cross_attn_o_proj = nullptr;
+    struct ggml_tensor * cross_attn_q_norm = nullptr;
+    struct ggml_tensor * cross_attn_q_proj = nullptr;
+    struct ggml_tensor * cross_attn_v_proj = nullptr;
+    struct ggml_tensor * cross_attn_attn_gate = nullptr;
+    struct ggml_tensor * cross_attn_mlp_gate = nullptr;
+
     struct llama_layer_posnet posnet;
 
     struct llama_layer_convnext convnext;
diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
index 7dc54227..223e1f3f 100644
--- a/src/llama-quant.cpp
+++ b/src/llama-quant.cpp
@@ -639,7 +639,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         if (llama_model_has_encoder(&model)) {
             n_attn_layer *= 3;
         }
-        GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
+        if (qs.n_attention_wv != n_attn_layer) {
+            LLAMA_LOG_WARN("%s: n_attention_wv is unexpected, expected: %d, found: %d\n", __func__, n_attn_layer, qs.n_attention_wv);
+        }
     }
 
     size_t total_size_org = 0;