lookup-stats.cpp
#include "arg.h"
#include "common.h"
#include "log.h"
#include "ngram-cache.h"
#include "llama.h"
#include "ggml.h"

#include <cstdint>
#include <cstdio>
#include <cinttypes>
#include <fstream>
#include <string>
#include <vector>

int main(int argc, char ** argv){
    common_params params;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LOOKUP)) {
        return 1;
    }

    common_init();

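    // maximum number of tokens to draft per lookup call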
    const int n_draft = params.speculative.n_max;

    // init llama.cpp
    llama_backend_init();
    llama_numa_init(params.numa);

    // load the model
    common_init_result llama_init = common_init_from_params(params);

    llama_model * model = llama_init.model;
    llama_context * ctx = llama_init.context;

    // tokenize the prompt
    std::vector<llama_token> inp;
    inp = common_tokenize(ctx, params.prompt, true, true);

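    // n-gram caches used for drafting:
    //  - context: rebuilt from the tokens of the current chunk, cleared after each chunk
    //  - dynamic: persists across chunks (the context cache is merged into it after each chunk)
    //  - static:  optional, loaded from a file that is typically built ahead of time from a large text corpus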
    common_ngram_cache ngram_cache_context;
    common_ngram_cache ngram_cache_dynamic;
    common_ngram_cache ngram_cache_static;

    int64_t t_draft_flat_us = 0;
    int64_t t_draft_us = 0;

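    // this block measures the one-off ("flat") cost of loading the lookup caches from disk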
    {
        const int64_t t_start_draft_us = ggml_time_us();

        if (!params.lookup_cache_static.empty()) {
            try {
                ngram_cache_static = common_ngram_cache_load(params.lookup_cache_static);
            } catch (std::ifstream::failure const &) {
                LOG_ERR("failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
                exit(1);
            }
        }

        if (!params.lookup_cache_dynamic.empty()) {
            try {
                ngram_cache_dynamic = common_ngram_cache_load(params.lookup_cache_dynamic);
            } catch (std::ifstream::failure const &) {} // not fatal: if the file does not exist, continue with an empty dynamic cache
        }

        t_draft_flat_us += ggml_time_us() - t_start_draft_us;
    }

    const int n_input = inp.size();
    const int n_ctx = llama_n_ctx(ctx);

    int n_drafted = 0;
    int n_accept  = 0;

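    // wall-clock start time, used for the ETA estimate printed periodically during long runs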
    const int64_t t_start_ms = ggml_time_ms();

    // Iterate over input tokens in chunks of size n_ctx.
    // Each chunk is treated as if it were a sequential generation, but with pre-determined tokens to ensure reproducibility.
    for (int i_start = 0; i_start + n_ctx < n_input; i_start += n_ctx) {
        const std::vector<llama_token> inp_slice(inp.begin() + i_start, inp.begin() + i_start + n_ctx);
        std::vector<llama_token> pseudo_output;
        pseudo_output.push_back(inp_slice[0]);

        while ((int) pseudo_output.size() < n_ctx) {
            // Simulate drafting and decoding from draft:
            std::vector<llama_token> draft;
            draft.push_back(pseudo_output.back());

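            // common_ngram_cache_draft appends up to n_draft candidate continuation tokens to the draft,
            // looked up from the context, dynamic and static n-gram caches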
            {
                const int64_t t_start_draft_us = ggml_time_us();
                common_ngram_cache_draft(pseudo_output, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
                t_draft_us += ggml_time_us() - t_start_draft_us;
            }

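            // draft[0] is the already-known last token, so only the appended tokens count as drafted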
            n_drafted += draft.size() - 1;

            for (size_t j = 1; j < draft.size() && (int) pseudo_output.size() < n_ctx; ++j) {
                const llama_token ground_truth = inp_slice[pseudo_output.size()];
                const llama_token drafted = draft[j];

                if (ground_truth != drafted) {
                    break;
                }

                ++n_accept;
                pseudo_output.push_back(ground_truth);

                {
                    const int64_t t_start_draft_us = ggml_time_us();
                    common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output, 1, false);
                    t_draft_us += ggml_time_us() - t_start_draft_us;
                }
            }

            // After each simulated batch decoding simulate the sampling of a single token:
            if ((int) pseudo_output.size() < n_ctx) {
                pseudo_output.push_back(inp_slice[pseudo_output.size()]);
                {
                    const int64_t t_start_draft_us = ggml_time_us();
                    common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output, 1, false);
                    t_draft_us += ggml_time_us() - t_start_draft_us;
                }
            }

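            // drop the seed token; the draft vector is rebuilt at the top of the next iteration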
            draft.erase(draft.begin());

        }
        if (i_start > 0 && i_start / 100000 != (i_start - n_ctx) / 100000) {
            const int64_t t_now_ms = ggml_time_ms();
            const int64_t eta_ms   = (n_input - i_start) * (t_now_ms - t_start_ms) / i_start;
            const int64_t eta_min  = eta_ms / (60*1000);
            const int64_t eta_s    = (eta_ms - 60*1000*eta_min) / 1000;

            LOG_INF("lookup-stats: %d/%d done, ETA: %02" PRId64 ":%02" PRId64 "\n", i_start, n_input, eta_min, eta_s);
        }

        // After each chunk, update the dynamic ngram cache with the context ngram cache:
        common_ngram_cache_merge(ngram_cache_dynamic, ngram_cache_context);
        ngram_cache_context.clear();
    }

    LOG("\n");

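    // summary: t_draft_flat is the one-off cost of loading the lookup caches from disk,
    // t_draft is the cumulative time spent in the n-gram drafting and cache-update calls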
    LOG_INF("\n");
    LOG_INF("n_draft      = %d\n", n_draft);
    LOG_INF("n_predict    = %d\n", n_input - n_input % n_ctx);
    LOG_INF("n_drafted    = %d\n", n_drafted);
    LOG_INF("t_draft_flat = %.2f ms\n", t_draft_flat_us*1e-3);
    LOG_INF("t_draft      = %.2f ms, %.2f us per token, %.2f tokens per second\n",
            t_draft_us*1e-3, 1.0f*t_draft_us/n_drafted, n_drafted/(1e-6*t_draft_us));
    LOG_INF("n_accept     = %d\n", n_accept);
    LOG_INF("accept       = %.3f%%\n", 100.0f * n_accept / n_drafted);

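    // free llama.cpp resources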
    llama_free(ctx);
    llama_free_model(model);

    llama_backend_free();

    LOG("\n\n");

    return 0;
}