"vscode:/vscode.git/clone" did not exist on "482944d94775e4963f781d76aa00ac7bbbc6b543"
passkey.cpp 8.17 KB
Newer Older
xuxzh1's avatar
update  
xuxzh1 committed
1
#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"

#include <cmath>
#include <cstdio>
#include <string>
#include <vector>

static void print_usage(int, char ** argv) {
    LOG("\nexample usage:\n");
    LOG("\n    %s -m model.gguf --junk 250 --pos 90 --keep 32 --grp-attn-n 2 [--seed 1234]\n", argv[0]);
    LOG("\n");
}

int main(int argc, char ** argv) {
    common_params params;

    params.n_junk = 250;
    params.n_keep = 32;
    params.i_pos  = -1;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PASSKEY, print_usage)) {
        return 1;
    }

    common_init();

    int n_junk = params.n_junk;
    int n_keep = params.n_keep;
    int n_grp  = params.grp_attn_n;
    int i_pos  = params.i_pos;

    // if no insertion position was given, pick a random junk-sentence index
    if (i_pos == -1) {
        i_pos = rand() % n_junk;
    }

    const std::string prompt_prefix = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.";
    const std::string prompt_suffix = " What is the pass key? The pass key is";

    // generate junk text
    params.prompt = prompt_prefix;

    // random passkey in the range [1, 50000]
    const int passkey = rand() % 50000 + 1;

    for (int i = 0; i < n_junk; i++) {
        if (i % n_junk == i_pos) { // i.e. i == i_pos, since 0 <= i < n_junk
            params.prompt += " The pass key is " + std::to_string(passkey) + ". Remember it. " + std::to_string(passkey) + " is the pass key.";
        }

        params.prompt += " The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.";
    }

    params.prompt += prompt_suffix;
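
    // the resulting prompt looks like:
    //   <prefix> <junk> ... " The pass key is NNNNN. Remember it. NNNNN is the pass key." <junk> ... <suffix>
    // with the passkey sentence inserted just before junk sentence number i_pos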

    // init LLM

    llama_backend_init();
    llama_numa_init(params.numa);

    // initialize the model

    llama_model_params model_params = common_model_params_to_llama(params);

    llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);

    if (model == NULL) {
        LOG_ERR("%s: unable to load model\n" , __func__);
        return 1;
    }

    // initialize the context

    llama_context_params ctx_params = common_context_params_to_llama(params);

    // SelfExtend stretches the usable context to n_grp times the training context;
    // the extra n_keep positions are reserved for the attention-sink prefix
    ctx_params.n_ctx = llama_n_ctx_train(model)*n_grp + n_keep;

    GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp");
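
    // for example, a model trained with a 4096-token context run with --grp-attn-n 4
    // and the default --keep 32 requests n_ctx = 4096*4 + 32 = 16416 positions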

    llama_context * ctx = llama_new_context_with_model(model, ctx_params);
    if (ctx == NULL) {
        LOG_ERR("%s: failed to create the llama_context\n" , __func__);
        return 1;
    }

    // greedy sampling is enough here - we only need the model's most likely answer
    auto sparams = llama_sampler_chain_default_params();

    llama_sampler * smpl = llama_sampler_chain_init(sparams);

    llama_sampler_chain_add(smpl, llama_sampler_init_greedy());

    // tokenize the prompt
    std::vector<llama_token> tokens_list = common_tokenize(ctx, params.prompt, true);

    // tokenize the prefix and use it as a sink
    const int n_tokens_prefix = common_tokenize(ctx, prompt_prefix, true).size();

    const int n_tokens_all = tokens_list.size();

    // we leave a margin of 16 tokens for the generated text - it should contain just the passkey
    const int n_predict = 16;

    // total length of the sequences including the prompt
    const int n_len = n_tokens_all + n_predict;

    const int n_ctx       = llama_n_ctx(ctx) - n_keep; // usable space for the prompt, excluding the n_keep sink tokens
    const int n_kv_req    = llama_n_ctx(ctx);
    const int n_batch     = ctx_params.n_batch;
    const int n_batch_grp = ctx_params.n_batch/n_grp;

    LOG_INF("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d, n_junk = %d, i_pos = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch, n_junk, i_pos);

    // print information about the prompt

    LOG_INF("\n");
    LOG_INF("prefix tokens: %d\n", n_tokens_prefix);
    LOG_INF("prompt tokens: %d\n", n_tokens_all);
    //LOG_INF("prompt: %s\n", params.prompt.c_str());

    // a reusable batch: capacity of n_batch tokens, token-id input, a single sequence
    llama_batch batch = llama_batch_init(params.n_batch, 0, 1);

    int n_past = 0;

    // fill the KV cache
    for (int i = 0; i < n_ctx; i += n_batch) {
        if (i > 0 && n_grp > 1) {
            // if SelfExtend is enabled, we compress the position from the last batch by a factor of n_grp
            const int ib = i/n_batch - 1;
            const int bd = n_batch_grp*(n_grp - 1);

            llama_kv_cache_seq_add (ctx, 0, n_past - n_batch,         n_past,         ib*bd);
            llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
            llama_kv_cache_update  (ctx);

            n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
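
            // e.g. with n_batch = 2048 and n_grp = 4 (n_batch_grp = 512, bd = 1536),
            // every fully processed batch of 2048 tokens ends up occupying only 512
            // positions, which is what lets the long prompt fit the extended context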
        }

        common_batch_clear(batch);

        for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
            common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
        }

        // once the final chunk of the prompt is reached, request logits for its last token
        if (i + n_batch >= n_tokens_all) {
            batch.logits[batch.n_tokens - 1] = true;
        }

        if (llama_decode(ctx, batch) != 0) {
            LOG_ERR("%s: llama_decode() failed\n", __func__);
            return 1;
        }

        LOG_INF("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all));

        if (i + n_batch >= n_tokens_all) {
            break;
        }
    }

    // if the prompt does not fit even in the extended context, process the remainder
    // by discarding older tokens (keeping the first n_keep) and shifting the cache
    for (int i = n_ctx; i < n_tokens_all; i += n_batch) {
        const int n_discard = n_batch;

        LOG_INF("%s: shifting KV cache with %d\n", __func__, n_discard);

        llama_kv_cache_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
        llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
      //llama_kv_cache_defrag (ctx);
        llama_kv_cache_update (ctx);

        n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
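
        // after the removal and shift, the retained tokens occupy positions
        // [n_keep, n_ctx - n_discard) and the freed tail receives the next batch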

        common_batch_clear(batch);

        for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
            common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
        }

        if (i + n_batch >= n_tokens_all) {
            batch.logits[batch.n_tokens - 1] = true;
        }

        if (llama_decode(ctx, batch) != 0) {
            LOG_ERR("%s: llama_decode() failed\n", __func__);
            return 1;
        }

        LOG_INF("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all));
    }

    // if needed, free space at the end of the context for the generated answer
    {
        const int n_discard = n_past - n_ctx + n_predict;

        if (n_discard > 0) {
            LOG_INF("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);

            llama_kv_cache_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
            llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
          //llama_kv_cache_defrag (ctx);
            llama_kv_cache_update (ctx);

            n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
        }
    }

    LOG_INF("\n");
    LOG_INF("%s: passkey = %d, inserted at position %d / %d (token pos: ~%d)\n", __func__, passkey, i_pos, n_junk, (i_pos * n_tokens_all) / n_junk);
    LOG_INF("\n");

    // main loop

    int n_cur    = n_tokens_all; // current sequence length, prompt + generated tokens
    int n_decode = 0;            // number of tokens generated so far

    LOG_INF("%s", prompt_suffix.c_str());

    const auto t_main_start = ggml_time_us();

    while (n_cur <= n_len) {
        // sample the next token
        {
            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);

            // is it an end of generation?
            if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
                LOG("\n");

                break;
            }

            LOG("%s", common_token_to_piece(ctx, new_token_id).c_str());

            n_decode += 1;

            // prepare the next batch
            common_batch_clear(batch);

            // push this new token for next evaluation
            common_batch_add(batch, new_token_id, n_past++, { 0 }, true); // logits needed for the next sampling step
        }

        n_cur += 1;

        // evaluate the current batch with the transformer model
        const int ret = llama_decode(ctx, batch);
        if (ret != 0) {
            LOG_ERR("%s : failed to eval, return code %d\n", __func__, ret);
            return 1;
        }
    }

    LOG("\n");

    const auto t_main_end = ggml_time_us();

    LOG_INF("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
            __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));

    LOG("\n");
    llama_perf_context_print(ctx);

    LOG("\n");

    llama_sampler_free(smpl);

    llama_batch_free(batch);

    llama_free(ctx);
    llama_free_model(model);

    llama_backend_free();

    return 0;
}