From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: jmorganca <jmorganca@gmail.com>
Date: Tue, 8 Apr 2025 16:03:51 -0700
Subject: [PATCH] solar-pro

adds support for the Solar Pro architecture
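
As a reading aid, here is a rough sketch of the block skip connection
("bskcn") that llm_build_solar wires up further down in this patch. The
names mirror the patch itself; the pseudocode is only illustrative, not
additional code to apply:

    // bskcn_tv carries two learned mixing weights, tv[0] and tv[1].
    // Layers flagged by %s.attention.block_skip_connection.{0,1} save the
    // residual stream; layers flagged by .{2,3} blend the saved copy back in:
    //
    //     if (hparams.n_bskcn(0, il)) bskcn_1 = inpSA;
    //     if (hparams.n_bskcn(1, il)) bskcn_2 = inpSA;
    //     if (hparams.n_bskcn(2, il)) inpSA = tv[0]*bskcn_1 + tv[1]*inpSA;
    //     if (hparams.n_bskcn(3, il)) inpSA = tv[0]*bskcn_2 + tv[1]*inpSA;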
---
 src/llama-arch.cpp         |  21 ++++
 src/llama-arch.h           |   3 +
 src/llama-hparams.cpp      |   8 ++
 src/llama-hparams.h        |   5 +
 src/llama-model-loader.cpp |   1 +
 src/llama-model.cpp        | 207 +++++++++++++++++++++++++++++++++++++
 src/llama-model.h          |   3 +
 7 files changed, 248 insertions(+)

diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index a6fddc7f..0b0fedcd 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -68,6 +68,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_GRANITE,          "granite"          },
     { LLM_ARCH_GRANITE_MOE,      "granitemoe"       },
     { LLM_ARCH_CHAMELEON,        "chameleon"        },
+    { LLM_ARCH_SOLAR,            "solar"            },
     { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
     { LLM_ARCH_PLM,              "plm"              },
     { LLM_ARCH_BAILINGMOE,       "bailingmoe"       },
@@ -140,6 +141,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,       "%s.attention.relative_buckets_count"       },
     { LLM_KV_ATTENTION_SLIDING_WINDOW,               "%s.attention.sliding_window"               },
     { LLM_KV_ATTENTION_SCALE,                        "%s.attention.scale"                        },
+    { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,        "%s.attention.block_skip_connection"        },
 
     { LLM_KV_ROPE_DIMENSION_COUNT,      "%s.rope.dimension_count"                 },
     { LLM_KV_ROPE_DIMENSION_SECTIONS,   "%s.rope.dimension_sections"              },
@@ -1478,6 +1480,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
         },
     },
+    {
+        LLM_ARCH_SOLAR,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_BSKCN_TV,        "bskcn_tv" },
+        },
+    },
     {
         LLM_ARCH_WAVTOKENIZER_DEC,
         {
@@ -1671,6 +1691,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_FFN_EXP_PROBS_B,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     // this tensor is loaded for T5, but never used
     {LLM_TENSOR_DEC_CROSS_ATTN_REL_B,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
+    {LLM_TENSOR_BSKCN_TV,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_CONV1D,                     {LLM_TENSOR_LAYER_INPUT,     GGML_OP_IM2COL}},
     {LLM_TENSOR_POS_NET_NORM,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_POS_NET_NORM1,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 2c2099b3..74aa3dd0 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -72,6 +72,7 @@ enum llm_arch {
     LLM_ARCH_GRANITE,
     LLM_ARCH_GRANITE_MOE,
     LLM_ARCH_CHAMELEON,
+    LLM_ARCH_SOLAR,
     LLM_ARCH_WAVTOKENIZER_DEC,
     LLM_ARCH_PLM,
     LLM_ARCH_BAILINGMOE,
@@ -144,6 +145,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
     LLM_KV_ATTENTION_SLIDING_WINDOW,
     LLM_KV_ATTENTION_SCALE,
+    LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,
 
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -340,6 +342,7 @@ enum llm_tensor {
     LLM_TENSOR_ENC_OUTPUT_NORM,
     LLM_TENSOR_CLS,
     LLM_TENSOR_CLS_OUT,
+    LLM_TENSOR_BSKCN_TV,
     LLM_TENSOR_CONV1D,
     LLM_TENSOR_CONVNEXT_DW,
     LLM_TENSOR_CONVNEXT_NORM,
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index 90dfe7a7..8a667960 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -70,6 +70,14 @@ uint32_t llama_hparams::n_embd_v_s() const {
     return ssm_d_state * ssm_d_inner;
 }
 
+bool llama_hparams::n_bskcn(uint32_t n, uint32_t il) const {
+    if (il < n_layer) {
+        return n_bskcn_arr[n][il] > 0;
+    }
+
+    GGML_ABORT("fatal error");
+}
+
 bool llama_hparams::is_swa(uint32_t il) const {
     if (il < n_layer) {
         return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index 4e0b5719..c3147cbc 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -51,6 +51,8 @@ struct llama_hparams {
     std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr;
     std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
 
+    std::array<std::array<uint32_t, LLAMA_MAX_LAYERS>, 4> n_bskcn_arr = {};
+
     uint32_t n_layer_dense_lead = 0;
     uint32_t n_lora_q           = 0;
     uint32_t n_lora_kv          = 0;
@@ -149,6 +151,9 @@ struct llama_hparams {
     // dimension of the recurrent state embeddings
     uint32_t n_embd_v_s() const;
 
+    // Block skip connection
+    bool n_bskcn(uint32_t n, uint32_t il) const;
+
     bool is_swa(uint32_t il) const;
 };
 
diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
index ea73a8a7..a012aeae 100644
--- a/src/llama-model-loader.cpp
+++ b/src/llama-model-loader.cpp
@@ -439,6 +439,7 @@ namespace GGUFMeta {
     // TODO: this is not very clever - figure out something better
     template bool llama_model_loader::get_key_or_arr<std::array<int, 4>>(enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required);
     template bool llama_model_loader::get_key_or_arr<std::array<uint32_t, 512>>(enum llm_kv kid, std::array<uint32_t, 512> & result, uint32_t n, bool required);
+    template bool llama_model_loader::get_key_or_arr<uint32_t>(const std::string & key, std::array<uint32_t, 512> & result, uint32_t n, bool required);
 
 llama_model_loader::llama_model_loader(
         const std::string & fname,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index b74dd72c..5fbd0055 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -1372,6 +1372,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_SOLAR:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                for (size_t i = 0; i < hparams.n_bskcn_arr.max_size(); ++i) {
+                    auto & bskcn = hparams.n_bskcn_arr[i];
+                    bskcn.fill(0);
+                    auto kv = LLM_KV(arch);
+                    ml.get_key_or_arr(format((kv(LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION) + ".%d").c_str(), i), bskcn, hparams.n_layer, false);
+                }
+
+                switch (hparams.n_layer) {
+                    case 64: type = LLM_TYPE_22B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_WAVTOKENIZER_DEC:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,    hparams.f_norm_eps);
@@ -3701,6 +3716,34 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
                         layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
 
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_SOLAR:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    {
+                        output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                        output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.bskcn_tv = create_tensor(tn(LLM_TENSOR_BSKCN_TV, "weight", i), {2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
                         layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
                         layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
                         layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
@@ -12244,6 +12287,165 @@ struct llm_build_chameleon : public llm_graph_context {
     }
 };
 
+struct llm_build_solar : public llm_graph_context {
+    llm_build_solar(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        auto * inp_attn = build_attn_inp_kv_unified();
+
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+        struct ggml_tensor * bskcn_1;
+        struct ggml_tensor * bskcn_2;
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            if (hparams.n_bskcn(0, il)) {
+                bskcn_1 = inpSA;
+            }
+
+            if (hparams.n_bskcn(1, il)) {
+                bskcn_2 = inpSA;
+            }
+
+            if (hparams.n_bskcn(2, il)) {
+                inpSA = ggml_add(
+                   ctx0,
+                   ggml_mul(ctx0, bskcn_1, ggml_view_1d(ctx0, model.layers[il].bskcn_tv, 1, 0)),
+                   ggml_mul(ctx0, inpSA, ggml_view_1d(ctx0, model.layers[il].bskcn_tv, 1, ggml_element_size(model.layers[il].bskcn_tv))));
+            }
+
+            if (hparams.n_bskcn(3, il)) {
+                inpSA = ggml_add(
+                   ctx0,
+                   ggml_mul(ctx0, bskcn_2, ggml_view_1d(ctx0, model.layers[il].bskcn_tv, 1, 0)),
+                   ggml_mul(ctx0, inpSA, ggml_view_1d(ctx0, model.layers[il].bskcn_tv, 1, ggml_element_size(model.layers[il].bskcn_tv))));
+            }
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // rope freq factors for llama3; may return nullptr for llama2 and other models
+                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                    model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                    model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 struct llm_build_wavtokenizer_dec : public llm_graph_context {
     llm_build_wavtokenizer_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         ggml_tensor * cur;
@@ -12993,6 +13195,10 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_chameleon>(*this, params, gf);
             } break;
+        case LLM_ARCH_SOLAR:
+            {
+                llm = std::make_unique<llm_build_solar>(*this, params, gf);
+            } break;
         case LLM_ARCH_WAVTOKENIZER_DEC:
             {
                 llm = std::make_unique<llm_build_wavtokenizer_dec>(*this, params, gf);
@@ -13139,6 +13345,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_GRANITE:
         case LLM_ARCH_GRANITE_MOE:
         case LLM_ARCH_CHAMELEON:
+        case LLM_ARCH_SOLAR:
         case LLM_ARCH_BAILINGMOE:
             return LLAMA_ROPE_TYPE_NORM;
 
diff --git a/src/llama-model.h b/src/llama-model.h
index 0f18dac1..e08d4ae4 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -62,6 +62,7 @@ enum llm_type {
     LLM_TYPE_15B,
     LLM_TYPE_16B,
     LLM_TYPE_20B,
+    LLM_TYPE_22B,
     LLM_TYPE_30B,
     LLM_TYPE_32B,
     LLM_TYPE_34B,
@@ -305,6 +306,8 @@ struct llama_layer {
     struct ggml_tensor * ffn_up_scale   = nullptr;
     struct ggml_tensor * ffn_down_scale = nullptr;
 
+    struct ggml_tensor * bskcn_tv = nullptr;
+
     struct llama_layer_posnet posnet;
 
     struct llama_layer_convnext convnext;