#include "arg.h"
#include "common.h"
#include "console.h"
#include "log.h"
#include "sampling.h"
#include "llama.h"

#include <cassert>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h>
#include <unistd.h>
#elif defined (_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
#include <signal.h>
#endif

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

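// globals set from main(); the SIGINT handler below uses the context, sampler, params and the interaction flags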
static llama_context           ** g_ctx;
static llama_model             ** g_model;
static common_sampler          ** g_smpl;
static common_params            * g_params;
static std::vector<llama_token> * g_input_tokens;
static std::ostringstream       * g_output_ss;
static std::vector<llama_token> * g_output_tokens;
static bool is_interacting  = false;
static bool need_insert_eot = false;

static void print_usage(int argc, char ** argv) {
    (void) argc;

    LOG("\nexample usage:\n");
    LOG("\n  text generation:     %s -m your_model.gguf -p \"I believe the meaning of life is\" -n 128\n", argv[0]);
    LOG("\n  chat (conversation): %s -m your_model.gguf -p \"You are a helpful assistant\" -cnv\n", argv[0]);
    LOG("\n");
}

static bool file_exists(const std::string & path) {
    std::ifstream f(path.c_str());
    return f.good();
}

static bool file_is_empty(const std::string & path) {
    std::ifstream f;
    f.exceptions(std::ifstream::failbit | std::ifstream::badbit);
    f.open(path.c_str(), std::ios::in | std::ios::binary | std::ios::ate);
    return f.tellg() == 0;
}

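// Ctrl+C handling: the first interrupt during interactive generation hands control back to the user
// (and requests that an EOT token be inserted); otherwise print timings and exit with status 130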
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
static void sigint_handler(int signo) {
    if (signo == SIGINT) {
        if (!is_interacting && g_params->interactive) {
            is_interacting  = true;
            need_insert_eot = true;
        } else {
            console::cleanup();
            LOG("\n");
            common_perf_print(*g_ctx, *g_smpl);

            // make sure all logs are flushed
            LOG("Interrupted by user\n");
            common_log_pause(common_log_main());

            _exit(130);
        }
    }
}
#endif

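// append a message to chat_msgs and return it rendered through the model's chat template
// (common_chat_format_single formats only the newly added message)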
static std::string chat_add_and_format(struct llama_model * model, std::vector<common_chat_msg> & chat_msgs, const std::string & role, const std::string & content) {
    common_chat_msg new_msg{role, content};
    auto formatted = common_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user");
    chat_msgs.push_back({role, content});
    LOG_DBG("formatted: '%s'\n", formatted.c_str());
    return formatted;
}

int main(int argc, char ** argv) {
    common_params params;
    g_params = &params;
    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_MAIN, print_usage)) {
        return 1;
    }

    common_init();

    auto & sparams = params.sampling;

    // save choice to use color for later
    // (note for later: this is a slightly awkward choice)
    console::init(params.simple_io, params.use_color);
    atexit([]() { console::cleanup(); });

    if (params.logits_all) {
        LOG_ERR("************\n");
        LOG_ERR("%s: please use the 'perplexity' tool for perplexity calculations\n", __func__);
        LOG_ERR("************\n\n");

        return 0;
    }

    if (params.embedding) {
        LOG_ERR("************\n");
        LOG_ERR("%s: please use the 'embedding' tool for embedding calculations\n", __func__);
        LOG_ERR("************\n\n");

        return 0;
    }

    if (params.n_ctx != 0 && params.n_ctx < 8) {
        LOG_WRN("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
        params.n_ctx = 8;
    }

    if (params.rope_freq_base != 0.0) {
        LOG_WRN("%s: warning: changing RoPE frequency base to %g.\n", __func__, params.rope_freq_base);
    }

    if (params.rope_freq_scale != 0.0) {
        LOG_WRN("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale);
    }

    LOG_INF("%s: llama backend init\n", __func__);

    llama_backend_init();
    llama_numa_init(params.numa);

    llama_model * model = nullptr;
    llama_context * ctx = nullptr;
    common_sampler * smpl = nullptr;

    std::vector<common_chat_msg> chat_msgs;

    g_model = &model;
    g_ctx = &ctx;
    g_smpl = &smpl;

    // load the model and apply lora adapter, if any
    LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__);
    common_init_result llama_init = common_init_from_params(params);

    model = llama_init.model;
    ctx = llama_init.context;

    if (model == NULL) {
        LOG_ERR("%s: error: unable to load model\n", __func__);
        return 1;
    }

    LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);

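    // CPU threadpool setup: batch (prompt processing) and generation params may differ;
    // a dedicated batch threadpool is created only when they do, and the generation
    // threadpool is then created in the paused state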
    struct ggml_threadpool_params tpp_batch =
            ggml_threadpool_params_from_cpu_params(params.cpuparams_batch);
    struct ggml_threadpool_params tpp =
            ggml_threadpool_params_from_cpu_params(params.cpuparams);

    set_process_priority(params.cpuparams.priority);

    struct ggml_threadpool * threadpool_batch = NULL;
    if (!ggml_threadpool_params_match(&tpp, &tpp_batch)) {
        threadpool_batch = ggml_threadpool_new(&tpp_batch);
        if (!threadpool_batch) {
            LOG_ERR("%s: batch threadpool create failed : n_threads %d\n", __func__, tpp_batch.n_threads);
            return 1;
        }

        // Start the non-batch threadpool in the paused state
        tpp.paused = true;
    }

    struct ggml_threadpool * threadpool = ggml_threadpool_new(&tpp);
    if (!threadpool) {
        LOG_ERR("%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads);
        return 1;
    }

    llama_attach_threadpool(ctx, threadpool, threadpool_batch);

    const int n_ctx_train = llama_n_ctx_train(model);
    const int n_ctx = llama_n_ctx(ctx);

    if (n_ctx > n_ctx_train) {
        LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx);
    }

    // print chat template example in conversation mode
    if (params.conversation) {
        if (params.enable_chat_template) {
            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, params.chat_template).c_str());
        } else {
            LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
        }
    }

    // print system information
    {
        LOG_INF("\n");
        LOG_INF("%s\n", common_params_get_system_info(params).c_str());
        LOG_INF("\n");
    }

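    // prompt cache ("session") support: a previously saved token list and KV state can be
    // reloaded so that a matching prompt prefix does not have to be re-evaluated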
    std::string path_session = params.path_prompt_cache;
    std::vector<llama_token> session_tokens;

    if (!path_session.empty()) {
        LOG_INF("%s: attempting to load saved session from '%s'\n", __func__, path_session.c_str());
        if (!file_exists(path_session)) {
            LOG_INF("%s: session file does not exist, will create.\n", __func__);
        } else if (file_is_empty(path_session)) {
            LOG_INF("%s: The session file is empty. A new session will be initialized.\n", __func__);
        } else {
            // The file exists and is not empty
            session_tokens.resize(n_ctx);
            size_t n_token_count_out = 0;
            if (!llama_state_load_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
                LOG_ERR("%s: failed to load session file '%s'\n", __func__, path_session.c_str());
                return 1;
            }
            session_tokens.resize(n_token_count_out);
            LOG_INF("%s: loaded a session with prompt size of %d tokens\n", __func__, (int)session_tokens.size());
        }
    }

    const bool add_bos = llama_add_bos_token(model);
    if (!llama_model_has_encoder(model)) {
        GGML_ASSERT(!llama_add_eos_token(model));
    }

    LOG_DBG("n_ctx: %d, add_bos: %d\n", n_ctx, add_bos);

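    // embd_inp holds the tokens still waiting to be fed to the model (the prompt now, user input later)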
    std::vector<llama_token> embd_inp;

    {
        auto prompt = (params.conversation && params.enable_chat_template && !params.prompt.empty())
            ? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
            : params.prompt;
        if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
            LOG_DBG("tokenize the prompt\n");
            embd_inp = common_tokenize(ctx, prompt, true, true);
        } else {
            LOG_DBG("use session tokens\n");
            embd_inp = session_tokens;
        }

        LOG_DBG("prompt: \"%s\"\n", prompt.c_str());
        LOG_DBG("tokens: %s\n", string_from(ctx, embd_inp).c_str());
    }

    // Should not run without any tokens
    if (embd_inp.empty()) {
        if (add_bos) {
            embd_inp.push_back(llama_token_bos(model));
            LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str());
        } else {
            LOG_ERR("input is empty\n");
            return -1;
        }
    }

    // make sure the prompt fits within the context window (leave a few tokens of headroom)
    if ((int) embd_inp.size() > n_ctx - 4) {
        LOG_ERR("%s: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
        return 1;
    }

    // debug message about similarity of saved session, if applicable
    size_t n_matching_session_tokens = 0;
    if (!session_tokens.empty()) {
        for (llama_token id : session_tokens) {
            if (n_matching_session_tokens >= embd_inp.size() || id != embd_inp[n_matching_session_tokens]) {
                break;
            }
            n_matching_session_tokens++;
        }
        if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) {
            LOG_INF("%s: using full prompt from session file\n", __func__);
        } else if (n_matching_session_tokens >= embd_inp.size()) {
            LOG_INF("%s: session file has exact match for prompt!\n", __func__);
        } else if (n_matching_session_tokens < (embd_inp.size() / 2)) {
            LOG_WRN("%s: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n",
                    __func__, n_matching_session_tokens, embd_inp.size());
        } else {
            LOG_INF("%s: session file matches %zu / %zu tokens of prompt\n",
                    __func__, n_matching_session_tokens, embd_inp.size());
        }

        // remove any "future" tokens that we might have inherited from the previous session
        llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
    }

    LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, session_tokens.size() %zu\n",
         embd_inp.size(), n_matching_session_tokens, session_tokens.size());

    // if we will use the cache for the full prompt without reaching the end of the cache, force
    // reevaluation of the last token to recalculate the cached logits
    if (!embd_inp.empty() && n_matching_session_tokens == embd_inp.size() && session_tokens.size() > embd_inp.size()) {
        LOG_DBG("recalculate the cached logits (do): session_tokens.resize( %zu )\n", embd_inp.size() - 1);

        session_tokens.resize(embd_inp.size() - 1);
    }

    // number of tokens to keep when resetting context
    if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size()) {
        params.n_keep = (int)embd_inp.size();
    } else {
        params.n_keep += add_bos; // always keep the BOS token
    }

    if (params.conversation) {
        params.interactive_first = true;
    }

    // enable interactive mode if interactive start is specified
    if (params.interactive_first) {
        params.interactive = true;
    }

    if (params.verbose_prompt) {
        LOG_INF("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
        LOG_INF("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
        for (int i = 0; i < (int) embd_inp.size(); i++) {
            LOG_INF("%6d -> '%s'\n", embd_inp[i], common_token_to_piece(ctx, embd_inp[i]).c_str());
        }

        if (params.n_keep > add_bos) {
            LOG_INF("%s: static prompt based on n_keep: '", __func__);
            for (int i = 0; i < params.n_keep; i++) {
                LOG_CNT("%s", common_token_to_piece(ctx, embd_inp[i]).c_str());
            }
            LOG_CNT("'\n");
        }
        LOG_INF("\n");
    }

    // ctrl+C handling
    {
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
        struct sigaction sigint_action;
        sigint_action.sa_handler = sigint_handler;
        sigemptyset (&sigint_action.sa_mask);
        sigint_action.sa_flags = 0;
        sigaction(SIGINT, &sigint_action, NULL);
#elif defined (_WIN32)
        auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
            return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false;
        };
        SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif
    }

    if (params.interactive) {
        LOG_INF("%s: interactive mode on.\n", __func__);

        if (!params.antiprompt.empty()) {
            for (const auto & antiprompt : params.antiprompt) {
                LOG_INF("Reverse prompt: '%s'\n", antiprompt.c_str());
                if (params.verbose_prompt) {
                    auto tmp = common_tokenize(ctx, antiprompt, false, true);
                    for (int i = 0; i < (int) tmp.size(); i++) {
                        LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx, tmp[i]).c_str());
                    }
                }
            }
        }

        if (params.input_prefix_bos) {
            LOG_INF("Input prefix with BOS\n");
        }

        if (!params.input_prefix.empty()) {
            LOG_INF("Input prefix: '%s'\n", params.input_prefix.c_str());
            if (params.verbose_prompt) {
                auto tmp = common_tokenize(ctx, params.input_prefix, true, true);
                for (int i = 0; i < (int) tmp.size(); i++) {
                    LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx, tmp[i]).c_str());
                }
            }
        }

        if (!params.input_suffix.empty()) {
            LOG_INF("Input suffix: '%s'\n", params.input_suffix.c_str());
            if (params.verbose_prompt) {
                auto tmp = common_tokenize(ctx, params.input_suffix, false, true);
                for (int i = 0; i < (int) tmp.size(); i++) {
                    LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx, tmp[i]).c_str());
                }
            }
        }
    }

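    // build the sampler chain from the command-line sampling parameters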
    smpl = common_sampler_init(model, sparams);
    if (!smpl) {
        LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__);
        return 1;
    }

    LOG_INF("sampler seed: %u\n",     common_sampler_get_seed(smpl));
    LOG_INF("sampler params: \n%s\n", sparams.print().c_str());
    LOG_INF("sampler chain: %s\n",    common_sampler_print(smpl).c_str());

    LOG_INF("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);

    // group-attention state
    // number of grouped KV tokens so far (used only if params.grp_attn_n > 1)
    int ga_i = 0;

    const int ga_n = params.grp_attn_n;
    const int ga_w = params.grp_attn_w;

    if (ga_n != 1) {
        GGML_ASSERT(ga_n > 0                    && "grp_attn_n must be positive");                     // NOLINT
        GGML_ASSERT(ga_w % ga_n == 0            && "grp_attn_w must be a multiple of grp_attn_n");     // NOLINT
      //GGML_ASSERT(n_ctx_train % ga_w == 0     && "n_ctx_train must be a multiple of grp_attn_w");    // NOLINT
      //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * grp_attn_n"); // NOLINT
        LOG_INF("self-extend: n_ctx_train = %d, grp_attn_n = %d, grp_attn_w = %d\n", n_ctx_train, ga_n, ga_w);
    }
    LOG_INF("\n");

    if (params.interactive) {
        const char * control_message;
        if (params.multiline_input) {
            control_message = " - To return control to the AI, end your input with '\\'.\n"
                              " - To return control without starting a new line, end your input with '/'.\n";
        } else {
            control_message = " - Press Return to return control to the AI.\n"
                              " - To return control without starting a new line, end your input with '/'.\n"
                              " - If you want to submit another line, end your input with '\\'.\n";
        }
        LOG_INF("== Running in interactive mode. ==\n");
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
        LOG_INF(       " - Press Ctrl+C to interject at any time.\n");
#endif
        LOG_INF(       "%s\n", control_message);

        is_interacting = params.interactive_first;
    }

    bool is_antiprompt        = false;
    bool input_echo           = true;
    bool display              = true;
    bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < embd_inp.size();

    int n_past             = 0;
    int n_remain           = params.n_predict;
    int n_consumed         = 0;
    int n_session_consumed = 0;

    std::vector<int>   input_tokens;  g_input_tokens  = &input_tokens;
    std::vector<int>   output_tokens; g_output_tokens = &output_tokens;
    std::ostringstream output_ss;     g_output_ss     = &output_ss;
    std::ostringstream assistant_ss; // for storing current assistant message, used in conversation mode

    // the first thing we will do is to output the prompt, so set color accordingly
    console::set_display(console::prompt);
    display = params.display_prompt;

    std::vector<llama_token> embd;

    // tokenized antiprompts
    std::vector<std::vector<llama_token>> antiprompt_ids;

    antiprompt_ids.reserve(params.antiprompt.size());
    for (const std::string & antiprompt : params.antiprompt) {
        antiprompt_ids.emplace_back(::common_tokenize(ctx, antiprompt, false, true));
    }

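    // encoder-decoder models: run the encoder over the prompt once, then seed the
    // decoder with its start token (falling back to BOS if none is defined)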
    if (llama_model_has_encoder(model)) {
        int enc_input_size = embd_inp.size();
        llama_token * enc_input_buf = embd_inp.data();

        if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size))) {
            LOG_ERR("%s : failed to eval\n", __func__);
            return 1;
        }

        llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
        if (decoder_start_token_id == -1) {
            decoder_start_token_id = llama_token_bos(model);
        }

        embd_inp.clear();
        embd_inp.push_back(decoder_start_token_id);
    }

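    // main loop: decode queued tokens (prompt or user input) in batches, then sample new tokens
    // one at a time until the budget runs out, an antiprompt/EOG is hit, or the user takes over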
    while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
        // predict
        if (!embd.empty()) {
            // Note: (n_ctx - 4) here is to match the logic for commandline prompt handling via
            // --prompt or --file which uses the same value.
            int max_embd_size = n_ctx - 4;

            // Ensure the input doesn't exceed the context size by truncating embd if necessary.
            if ((int) embd.size() > max_embd_size) {
                const int skipped_tokens = (int) embd.size() - max_embd_size;
                embd.resize(max_embd_size);

                console::set_display(console::error);
                LOG_WRN("<<input too long: skipped %d token%s>>", skipped_tokens, skipped_tokens != 1 ? "s" : "");
                console::set_display(console::reset);
            }

            if (ga_n == 1) {
                // infinite text generation via context shifting
                // if we run out of context:
                // - take the n_keep first tokens from the original prompt (via n_past)
                // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches

                if (n_past + (int) embd.size() >= n_ctx) {
                    if (!params.ctx_shift){
                        LOG_DBG("\n\n%s: context full and context shift is disabled => stopping\n", __func__);
                        break;
                    }

                    if (params.n_predict == -2) {
                        LOG_DBG("\n\n%s: context full and n_predict == %d => stopping\n", __func__, params.n_predict);
                        break;
                    }

                    const int n_left    = n_past - params.n_keep;
                    const int n_discard = n_left/2;

                    LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
                            n_past, n_left, n_ctx, params.n_keep, n_discard);

                    llama_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
                    llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);

                    n_past -= n_discard;

                    LOG_DBG("after swap: n_past = %d\n", n_past);

                    LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());

                    LOG_DBG("clear session path\n");
                    path_session.clear();
                }
            } else {
                // context extension via Self-Extend
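                // whenever n_past crosses the current group window, shift the KV cache forward,
                // compress ga_w positions by a factor of ga_n, and shift the tail back so that
                // positions stay within the model's trained context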
                while (n_past >= ga_i + ga_w) {
                    const int ib = (ga_n*ga_i)/ga_w;
                    const int bd = (ga_w/ga_n)*(ga_n - 1);
                    const int dd = (ga_w/ga_n) - ib*bd - ga_w;

                    LOG_DBG("\n");
                    LOG_DBG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i, n_past, ib*bd, ga_i + ib*bd, n_past + ib*bd);
                    LOG_DBG("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
                    LOG_DBG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);

                    llama_kv_cache_seq_add(ctx, 0, ga_i,                n_past,              ib*bd);
                    llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd,        ga_i + ib*bd + ga_w, ga_n);
                    llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd,      dd);

                    n_past -= bd;

                    ga_i += ga_w/ga_n;

                    LOG_DBG("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", n_past + bd, n_past, ga_i);
                }
            }

            // try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
            if (n_session_consumed < (int) session_tokens.size()) {
                size_t i = 0;
                for ( ; i < embd.size(); i++) {
                    if (embd[i] != session_tokens[n_session_consumed]) {
                        session_tokens.resize(n_session_consumed);
                        break;
                    }

                    n_past++;
                    n_session_consumed++;

                    if (n_session_consumed >= (int) session_tokens.size()) {
                        ++i;
                        break;
                    }
                }
                if (i > 0) {
                    embd.erase(embd.begin(), embd.begin() + i);
                }
            }

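            // evaluate the queued tokens in chunks of at most n_batch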
            for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
                int n_eval = (int) embd.size() - i;
                if (n_eval > params.n_batch) {
                    n_eval = params.n_batch;
                }

                LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());

                if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
                    LOG_ERR("%s : failed to eval\n", __func__);
                    return 1;
                }

                n_past += n_eval;

                LOG_DBG("n_past = %d\n", n_past);
                // Display total tokens alongside total time
                if (params.n_print > 0 && n_past % params.n_print == 0) {
                    LOG_DBG("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx);
                }
            }

            if (!embd.empty() && !path_session.empty()) {
                session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
                n_session_consumed = session_tokens.size();
            }
        }

        embd.clear();

        if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
            // optionally save the session on first sample (for faster prompt loading next time)
            if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) {
                need_to_save_session = false;
                llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());

                LOG_DBG("saved session to %s\n", path_session.c_str());
            }

            const llama_token id = common_sampler_sample(smpl, ctx, -1);

            common_sampler_accept(smpl, id, /* accept_grammar= */ true);

            // LOG_DBG("last: %s\n", string_from(ctx, smpl->prev.to_vector()).c_str());

            embd.push_back(id);

            // echo this to console
            input_echo = true;

            // decrement remaining sampling budget
            --n_remain;

            LOG_DBG("n_remain: %d\n", n_remain);
654
        } else {
            // some user input remains from prompt or interaction, forward it to processing
            LOG_DBG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
            while ((int) embd_inp.size() > n_consumed) {
                embd.push_back(embd_inp[n_consumed]);

                // push the prompt into the sampling context in order to apply repetition penalties later
                // for the prompt, we don't apply grammar rules
                common_sampler_accept(smpl, embd_inp[n_consumed], /* accept_grammar= */ false);

                ++n_consumed;
                if ((int) embd.size() >= params.n_batch) {
                    break;
                }
            }
        }

        // display text
        if (input_echo && display) {
            for (auto id : embd) {
                const std::string token_str = common_token_to_piece(ctx, id, params.special);

                // Console/Stream Output
                LOG("%s", token_str.c_str());

                // Record Displayed Tokens To Log
                // Note: Generated tokens are created one by one hence this check
                if (embd.size() > 1) {
                    // Incoming Requested Tokens
                    input_tokens.push_back(id);
                } else {
                    // Outgoing Generated Tokens
                    output_tokens.push_back(id);
                    output_ss << token_str;
                }
            }
        }

        // reset color to default if there is no pending user input
        if (input_echo && (int) embd_inp.size() == n_consumed) {
            console::set_display(console::reset);
            display = true;
        }

        // if not currently processing queued inputs
        if ((int) embd_inp.size() <= n_consumed) {
            // check for reverse prompt in the last n_prev tokens
            if (!params.antiprompt.empty()) {
                const int n_prev = 32;
                const std::string last_output = common_sampler_prev_str(smpl, ctx, n_prev);

                is_antiprompt = false;
                // Check if each of the reverse prompts appears at the end of the output.
                // If we're not running interactively, the reverse prompt might be tokenized with some following characters
                // so we'll compensate for that by widening the search window a bit.
                for (std::string & antiprompt : params.antiprompt) {
                    size_t extra_padding = params.interactive ? 0 : 2;
                    size_t search_start_pos = last_output.length() > static_cast<size_t>(antiprompt.length() + extra_padding)
                        ? last_output.length() - static_cast<size_t>(antiprompt.length() + extra_padding)
                        : 0;

                    if (last_output.find(antiprompt, search_start_pos) != std::string::npos) {
                        if (params.interactive) {
                            is_interacting = true;
                        }
                        is_antiprompt = true;
                        break;
                    }
                }

                // check for reverse prompt using special tokens
                llama_token last_token = common_sampler_last(smpl);
                for (std::vector<llama_token> ids : antiprompt_ids) {
                    if (ids.size() == 1 && last_token == ids[0]) {
                        if (params.interactive) {
                            is_interacting = true;
                        }
                        is_antiprompt = true;
                        break;
                    }
                }

                if (is_antiprompt) {
                    LOG_DBG("found antiprompt: %s\n", last_output.c_str());
                }
            }

            // deal with end of generation tokens in interactive mode
            if (llama_token_is_eog(model, common_sampler_last(smpl))) {
                LOG_DBG("found an EOG token\n");

                if (params.interactive) {
                    if (!params.antiprompt.empty()) {
                        // tokenize and inject first reverse prompt
                        const auto first_antiprompt = common_tokenize(ctx, params.antiprompt.front(), false, true);
                        embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
                        is_antiprompt = true;
                    }

                    if (params.enable_chat_template) {
                        chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str());
                    }
                    is_interacting = true;
                    LOG("\n");
                }
            }

            // if current token is not EOG, we add it to current assistant message
            if (params.conversation) {
                const auto id = common_sampler_last(smpl);
                assistant_ss << common_token_to_piece(ctx, id, false);
            }

            if (n_past > 0 && is_interacting) {
                LOG_DBG("waiting for user input\n");

                if (params.conversation) {
                    LOG("\n> ");
                }

                if (params.input_prefix_bos) {
                    LOG_DBG("adding input prefix BOS token\n");
                    embd_inp.push_back(llama_token_bos(model));
                }

                std::string buffer;
                if (!params.input_prefix.empty() && !params.conversation) {
                    LOG_DBG("appending input prefix: '%s'\n", params.input_prefix.c_str());
                    LOG("%s", params.input_prefix.c_str());
                }

                // color user input only
                console::set_display(console::user_input);
                display = params.display_prompt;

                std::string line;
                bool another_line = true;
                do {
                    another_line = console::readline(line, params.multiline_input);
                    buffer += line;
                } while (another_line);

                // done taking input, reset color
                console::set_display(console::reset);
                display = true;

                // Add tokens to embd only if the input buffer is non-empty
                // Entering an empty line lets the user pass control back
                if (buffer.length() > 1) {
                    // append input suffix if any
                    if (!params.input_suffix.empty() && !params.conversation) {
                        LOG_DBG("appending input suffix: '%s'\n", params.input_suffix.c_str());
                        LOG("%s", params.input_suffix.c_str());
                    }

                    LOG_DBG("buffer: '%s'\n", buffer.c_str());

                    const size_t original_size = embd_inp.size();

                    if (params.escape) {
                        string_process_escapes(buffer);
                    }

                    bool format_chat = params.conversation && params.enable_chat_template;
                    std::string user_inp = format_chat
                        ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
                        : std::move(buffer);
                    // TODO: one inconvenience of the current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
                    const auto line_pfx = common_tokenize(ctx, params.input_prefix, false, true);
                    const auto line_inp = common_tokenize(ctx, user_inp,            false, format_chat);
                    const auto line_sfx = common_tokenize(ctx, params.input_suffix, false, true);

                    LOG_DBG("input tokens: %s\n", string_from(ctx, line_inp).c_str());

                    // if the user stops generation mid-way, we must add an EOT to finish the model's last response
                    if (need_insert_eot && format_chat) {
                        llama_token eot = llama_token_eot(model);
                        embd_inp.push_back(eot == -1 ? llama_token_eos(model) : eot);
                        need_insert_eot = false;
                    }

                    embd_inp.insert(embd_inp.end(), line_pfx.begin(), line_pfx.end());
                    embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
                    embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end());

                    for (size_t i = original_size; i < embd_inp.size(); ++i) {
                        const llama_token token = embd_inp[i];
                        output_tokens.push_back(token);
                        output_ss << common_token_to_piece(ctx, token);
                    }

                    // reset assistant message
                    assistant_ss.str("");

                    n_remain -= line_inp.size();
                    LOG_DBG("n_remain: %d\n", n_remain);
                } else {
                    LOG_DBG("empty line, passing control back\n");
                }

                input_echo = false; // do not echo this again
            }

            if (n_past > 0) {
                if (is_interacting) {
                    common_sampler_reset(smpl);
                }
                is_interacting = false;
            }
        }

        // end of generation
        if (!embd.empty() && llama_token_is_eog(model, embd.back()) && !(params.interactive)) {
            LOG(" [end of text]\n");
            break;
        }

        // In interactive mode, respect the maximum number of tokens and drop back to user input when reached.
        // We skip this logic when n_predict == -1 (infinite) or -2 (stop at context size).
        if (params.interactive && n_remain <= 0 && params.n_predict >= 0) {
            n_remain = params.n_predict;
            is_interacting = true;
        }
    }

    if (!path_session.empty() && params.prompt_cache_all && !params.prompt_cache_ro) {
        LOG("\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str());
        llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
    }

    LOG("\n\n");
    common_perf_print(ctx, smpl);

    common_sampler_free(smpl);

    llama_free(ctx);
    llama_free_model(model);

    llama_backend_free();

    ggml_threadpool_free(threadpool);
    ggml_threadpool_free(threadpool_batch);

    return 0;
}