#include "ngram-cache.h"
#include "common.h"
#include "log.h"

#include <algorithm>
#include <cinttypes>
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <thread>

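// For every ngram length in [ngram_min, ngram_max], count which token followed
// each ngram, scanning only positions covered by the last nnew tokens of inp
// (pass nnew == inp.size() to build the cache from scratch).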
void common_ngram_cache_update(common_ngram_cache & ngram_cache, int ngram_min, int ngram_max,
                              std::vector<llama_token> & inp, int nnew, bool print_progress) {
    const int64_t t_start_ms = ggml_time_ms();
    const int64_t inp_size = inp.size();

    const int64_t n_todo = inp_size * (ngram_max - ngram_min + 1);
    int64_t n_done = 0;

    for (int64_t ngram_size = ngram_min; ngram_size <= ngram_max; ++ngram_size) {
        const int64_t i_start = std::max(inp_size - nnew, ngram_size);
        for (int64_t i = i_start; i < inp_size; ++i) {
            const int64_t ngram_start = i - ngram_size;
            common_ngram ngram(&inp[ngram_start], ngram_size);
            const llama_token token = inp[i];

            common_ngram_cache::iterator part_it = ngram_cache.find(ngram);
            if (part_it == ngram_cache.end()) {
                common_ngram_cache_part part;
                part.emplace(token, 1);
                ngram_cache.emplace(ngram, part);
            } else {
                common_ngram_cache_part::iterator token_count_it = part_it->second.find(token);
                if (token_count_it == part_it->second.end()) {
                    part_it->second.emplace(token, 1);
                } else {
                    token_count_it->second++;
                }
            }
            ++n_done;

            if (print_progress && n_done % 10000000 == 0) {
                const int64_t t_now_ms = ggml_time_ms();
                const int64_t eta_ms   = (n_todo - n_done) * (t_now_ms - t_start_ms) / n_done;
                const int64_t eta_min  = eta_ms / (60*1000);
                const int64_t eta_s    = (eta_ms - 60*1000*eta_min) / 1000;

                fprintf(stderr, "%s: %" PRId64 "/%" PRId64 " done, ETA: %02" PRId64 ":%02" PRId64 "\n", __func__, n_done, n_todo, eta_min, eta_s);
            }
        }
    }
}
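
// Usage sketch (illustrative, not part of this file): after appending newly
// accepted tokens to a token history, only the new suffix needs to be
// re-scanned. `tokens` and `n_new` are placeholder names, and 1/4 are example
// ngram bounds:
//
//     common_ngram_cache nc_context;
//     common_ngram_cache_update(nc_context, /*ngram_min=*/1, /*ngram_max=*/4,
//                               tokens, /*nnew=*/n_new, /*print_progress=*/false);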

// Helper function to get a token from the combined, speculative sequence of inp and draft.
static llama_token get_token(const std::vector<llama_token> & inp, const std::vector<llama_token> & draft, const size_t i) {
    return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
}

// If sample size or percentage are below these thresholds the draft is aborted early:
constexpr int    draft_min_sample_size_lax[LLAMA_NGRAM_MAX] = { 2,  2,  1,  1};
constexpr int        draft_min_percent_lax[LLAMA_NGRAM_MAX] = {66, 50, 50, 50};
constexpr int draft_min_sample_size_strict[LLAMA_NGRAM_MAX] = { 4,  3,  2,  2};
constexpr int     draft_min_percent_strict[LLAMA_NGRAM_MAX] = {75, 66, 66, 66};
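
// Worked example: for the smallest ngram size (index 0) under the lax settings,
// a token is only drafted if its ngram was observed at least 2 times in total
// and the most frequent continuation accounts for at least 66% of those
// observations, i.e. 100*max_count >= 66*sum_count.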

// Helper function that tries to draft a token from only the static ngram cache:
static llama_token try_draft(common_ngram_cache & nc_static, const common_ngram ngram_static) {
    common_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
    if (part_static_it == nc_static.end()) {
        return -1;
    }
    const common_ngram_cache_part part_static = part_static_it->second;

    int max_count_static  = 0;
    int sum_count_static  = 0;
    llama_token max_token = -1;

    for (std::pair<llama_token, int> token_count_static : part_static) {
        const llama_token token = token_count_static.first;
        const int32_t count_static  = token_count_static.second;

        if (count_static > max_count_static) {
            max_token        = token;
            max_count_static = count_static;
        }
        sum_count_static += count_static;
    }

    if (sum_count_static < draft_min_sample_size_lax[LLAMA_NGRAM_STATIC-1]) {
        return -1;
    }
    if (100*max_count_static < draft_min_percent_lax[LLAMA_NGRAM_STATIC-1]*sum_count_static) {
        return -1;
    }
    return max_token;
}

// Try to draft a token from primary cache (context/dynamic), validate with static cache:
static llama_token try_draft(
    common_ngram_cache & nc_primary, const std::vector<common_ngram> & ngrams_primary, common_ngram_cache_part & part_static,
    const int * min_sample_size, const int * min_percent) {

    llama_token drafted_token = -1;

    for (int i = ngrams_primary.size()-1; i >= 0 && drafted_token == -1; --i) {
        const common_ngram ngram_primary = ngrams_primary[i];

        common_ngram_cache::iterator part_primary_it = nc_primary.find(ngram_primary);
        if (part_primary_it == nc_primary.end()) {
            continue;
        }
        const common_ngram_cache_part part_primary = part_primary_it->second;

        int max_count_primary = 0;
        int max_count_static  = 0;
        int sum_count_primary = 0;
        llama_token max_token = -1;

        for (std::pair<llama_token, int> token_count_primary : part_primary) {
            const llama_token token = token_count_primary.first;

            common_ngram_cache_part::iterator token_count_static_it = part_static.find(token);

            const int32_t count_primary = token_count_primary.second;
            const int32_t count_static  = token_count_static_it != part_static.end() ? 100*token_count_static_it->second : 1;

            if (count_primary*count_static > max_count_primary*max_count_static) {
                max_token         = token;
                max_count_primary = count_primary;
                max_count_static  = count_static;
            }
            sum_count_primary += count_primary;
        }

        if (sum_count_primary < min_sample_size[i]) {
            continue;
        }
        if (100*max_count_primary < min_percent[i]*sum_count_primary) {
            continue;
        }
        drafted_token = max_token;
    }

    return drafted_token;
}

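// Extend draft with up to n_draft speculated tokens. draft must initially hold
// exactly one seed token; candidates are taken from the context cache (lax
// thresholds), then the dynamic cache (strict thresholds), both validated
// against the static cache, before falling back to the static cache alone.
// Drafting stops as soon as no cache yields a confident candidate.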
void common_ngram_cache_draft(
    std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
    common_ngram_cache & nc_context, common_ngram_cache & nc_dynamic, common_ngram_cache & nc_static
) {
    GGML_ASSERT(draft.size() == 1);
    const int inp_size = inp.size();

    if (inp_size < LLAMA_NGRAM_STATIC) {
        return;
    }

    while ((int) draft.size()-1 < n_draft) {
        llama_token drafted_token = -1;

        const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + draft.size()-1;
        common_ngram ngram_static;
        for (int j = ngram_start_static; j < ngram_start_static + LLAMA_NGRAM_STATIC; ++j) {
            ngram_static.tokens[j-ngram_start_static] = get_token(inp, draft, j);
        }
        common_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
        common_ngram_cache_part part_static;
        if (part_static_it != nc_static.end()) {
            part_static = part_static_it->second;
        }

        // cd = context + dynamic
        std::vector<common_ngram> ngrams_cd;
        for (int ngram_size_cd = ngram_min; ngram_size_cd <= ngram_max; ++ngram_size_cd) {
            const int ngram_start_cd = inp_size-ngram_size_cd + draft.size()-1;
            common_ngram ngram_cd;
            for (int j = ngram_start_cd; j < ngram_start_cd + ngram_size_cd; ++j) {
                ngram_cd.tokens[j-ngram_start_cd] = get_token(inp, draft, j);
            }
            ngrams_cd.push_back(ngram_cd);
        }
        if (drafted_token == -1) {
            drafted_token = try_draft(nc_context, ngrams_cd, part_static, draft_min_sample_size_lax, draft_min_percent_lax);
        }
        if (drafted_token == -1) {
            drafted_token = try_draft(nc_dynamic, ngrams_cd, part_static, draft_min_sample_size_strict, draft_min_percent_strict);
        }
        if (drafted_token == -1) {
            drafted_token = try_draft(nc_static, ngram_static);
        }

        if (drafted_token == -1) {
            break;
        }

        LOG(" - draft candidate: token=%d\n", drafted_token);
        draft.push_back(drafted_token);
    }
}
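
// Usage sketch (illustrative): seed the draft with the last sampled token and
// let the caches speculate; `last_token` and the cache names are placeholders:
//
//     std::vector<llama_token> draft = {last_token};
//     common_ngram_cache_draft(inp, draft, /*n_draft=*/8, ngram_min, ngram_max,
//                              nc_context, nc_dynamic, nc_static);
//     // draft[1..] now holds the speculated continuation (possibly empty).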

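// On-disk format, as written here and read back by common_ngram_cache_load:
// one record per ngram, laid out as [common_ngram][int32_t ntokens] followed
// by ntokens pairs of [llama_token][int32_t count].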
void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename) {
    std::ofstream file_out(filename, std::ios::binary);
    for (std::pair<common_ngram, common_ngram_cache_part> item : ngram_cache) {
        const common_ngram      ngram        = item.first;
        common_ngram_cache_part token_counts = item.second;
        GGML_ASSERT(!token_counts.empty());
        const int32_t ntokens = token_counts.size();
        GGML_ASSERT(ntokens > 0);

        file_out.write(reinterpret_cast<const char *>(&ngram),   sizeof(common_ngram));
        file_out.write(reinterpret_cast<const char *>(&ntokens), sizeof(int32_t));
        for (std::pair<llama_token, int32_t> item2 : token_counts) {
            const llama_token token = item2.first;
            const int32_t     count = item2.second;
            GGML_ASSERT(count > 0);

            file_out.write(reinterpret_cast<const char *>(&token), sizeof(llama_token));
            file_out.write(reinterpret_cast<const char *>(&count), sizeof(int32_t));
        }
    }
}

common_ngram_cache common_ngram_cache_load(std::string & filename) {
    std::ifstream hashmap_file(filename, std::ios::binary);
    if (!hashmap_file) {
        throw std::ifstream::failure("Unable to open file " + filename);
    }
    common_ngram_cache ngram_cache;

    common_ngram ngram;
    int32_t     ntokens;
    llama_token token;
    int32_t     count;

    char * ngramc   = reinterpret_cast<char*>(&ngram);
    char * ntokensc = reinterpret_cast<char*>(&ntokens);
    char * tokenc   = reinterpret_cast<char*>(&token);
    char * countc   = reinterpret_cast<char*>(&count);
    while (hashmap_file.read(ngramc, sizeof(common_ngram))) {
        GGML_ASSERT(!hashmap_file.eof());
        GGML_ASSERT(hashmap_file.read(ntokensc, sizeof(int32_t)));
        GGML_ASSERT(ntokens > 0);
        common_ngram_cache_part token_counts;

        for (int i = 0; i < ntokens; ++i) {
            GGML_ASSERT(!hashmap_file.eof());
            GGML_ASSERT(hashmap_file.read(tokenc, sizeof(llama_token)));
            GGML_ASSERT(!hashmap_file.eof());
            GGML_ASSERT(hashmap_file.read(countc, sizeof(int32_t)));
            GGML_ASSERT(count > 0);
            token_counts.emplace(token, count);
        }

        ngram_cache.emplace(ngram, token_counts);
    }
    GGML_ASSERT(hashmap_file.eof());

    return ngram_cache;
}
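
// Round-trip sketch (illustrative; the file name is a placeholder):
//
//     std::string fname = "ngram-cache.bin";
//     common_ngram_cache_save(cache, fname);
//     common_ngram_cache restored = common_ngram_cache_load(fname);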

void common_ngram_cache_merge(common_ngram_cache & ngram_cache_target, common_ngram_cache & ngram_cache_add) {
    for (std::pair<common_ngram, common_ngram_cache_part> ngram_part : ngram_cache_add) {
        const common_ngram      ngram = ngram_part.first;
        common_ngram_cache_part  part = ngram_part.second;

        common_ngram_cache::iterator part_merged_it = ngram_cache_target.find(ngram);
        if (part_merged_it == ngram_cache_target.end()) {
            ngram_cache_target.emplace(ngram, part);
            continue;
        }

        for (std::pair<llama_token, int32_t> token_count : part) {
            const llama_token token = token_count.first;
            const int32_t     count = token_count.second;
            GGML_ASSERT(count > 0);

            common_ngram_cache_part::iterator token_count_merged_it = part_merged_it->second.find(token);
            if (token_count_merged_it == part_merged_it->second.end()) {
                part_merged_it->second.emplace(token, count);
                continue;
            }

            token_count_merged_it->second += count;
        }
    }
}
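
// Merge sketch (illustrative): counts from the second cache are folded into
// the first, so repeated merges accumulate statistics across runs:
//
//     common_ngram_cache_merge(nc_accumulated, nc_session);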