#include "sampling.h"

#include "common.h"

#include <cmath>
#include <stdexcept>
#include <unordered_map>
#include <vector>
// the ring buffer works similarly to std::deque, but with a fixed capacity
// TODO: deduplicate with llama-impl.h
template<typename T>
struct ring_buffer {
    ring_buffer(size_t cap) : capacity(cap), data(cap) {}

    // oldest element; throws if the buffer is empty
    T & front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    const T & front() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    // newest element; throws if the buffer is empty
    // note: `pos` is the next write slot, so the last written element is one slot back
    T & back() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[(pos + capacity - 1) % capacity];
    }

    const T & back() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[(pos + capacity - 1) % capacity];
    }

    // append an element, evicting the oldest one when the buffer is full
    void push_back(const T & value) {
        if (capacity == 0) {
            // avoid modulo-by-zero below
            throw std::runtime_error("ring buffer: capacity is zero");
        }
        if (sz == capacity) {
            // advance the start when buffer is full
            first = (first + 1) % capacity;
        } else {
            sz++;
        }
        data[pos] = value;
        pos = (pos + 1) % capacity;
    }

    // remove and return the oldest element; throws if the buffer is empty
    T pop_front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        T value = data[first];
        first = (first + 1) % capacity;
        sz--;
        return value;
    }

    // reverse-at: rat(0) is the newest element, rat(sz - 1) the oldest
    const T & rat(size_t i) const {
        if (i >= sz) {
            throw std::runtime_error("ring buffer: index out of bounds");
        }
        return data[(first + sz - i - 1) % capacity];
    }

    // copy the contents, ordered oldest to newest
    std::vector<T> to_vector() const {
        std::vector<T> result;
        result.reserve(sz);
        for (size_t i = 0; i < sz; i++) {
            result.push_back(data[(first + i) % capacity]);
        }
        return result;
    }

    void clear() {
        // here only reset the status of the buffer
        sz = 0;
        first = 0;
        pos = 0;
    }

    bool empty() const {
        return sz == 0;
    }

    size_t size() const {
        return sz;
    }

    size_t capacity = 0;
    size_t sz = 0;    // number of stored elements
    size_t first = 0; // index of the oldest element
    size_t pos = 0;   // index of the next write slot
    std::vector<T> data;
};
struct common_sampler {
    common_params_sampling params;
103

104
105
    struct llama_sampler * grmr;
    struct llama_sampler * chain;
106

107
    ring_buffer<llama_token> prev;
108

109
    std::vector<llama_token_data> cur;
110

111
    llama_token_data_array cur_p;
112

113
114
    void set_logits(struct llama_context * ctx, int idx) {
        const auto * logits = llama_get_logits_ith(ctx, idx);
115

116
117
118
119
120
121
122
123
124
125
126
127
        const int n_vocab = llama_n_vocab(llama_get_model(ctx));

        cur.resize(n_vocab);

        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
            cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
        }

        cur_p = { cur.data(), cur.size(), -1, false };
    }
};

128
std::string common_params_sampling::print() const {
129
130
131
132
    char result[1024];

    snprintf(result, sizeof(result),
            "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
133
134
            "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
            "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
135
            "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
136
            penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
137
138
            dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
            top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
139
            mirostat, mirostat_eta, mirostat_tau);
140
141
142
143

    return std::string(result);
}

144
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
145
146
147
148
    llama_sampler_chain_params lparams = llama_sampler_chain_default_params();

    lparams.no_perf = params.no_perf;

149
    auto * result = new common_sampler {
150
151
152
153
154
155
156
157
158
159
160
161
162
163
        /* .params = */ params,
        /* .grmr   = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
        /* .chain  = */ llama_sampler_chain_init(lparams),
        /* .prev   = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
        /* .cur    = */ {},
        /* .cur_p  = */ {},
    };

    llama_sampler_chain_add(result->chain,
            llama_sampler_init_logit_bias(
                llama_n_vocab(model),
                params.logit_bias.size(),
                params.logit_bias.data()));

164
165
166
    if (params.mirostat == 0) {
        for (const auto & cnstr : params.samplers) {
            switch (cnstr) {
167
                case COMMON_SAMPLER_TYPE_DRY:
168
                    {
169
                        std::vector<const char *> c_breakers;
170
                        c_breakers.reserve(params.dry_sequence_breakers.size());
171
                        for (const auto & str : params.dry_sequence_breakers) {
172
173
174
175
176
                            c_breakers.push_back(str.c_str());
                        }

                        llama_sampler_chain_add(result->chain, llama_sampler_init_dry      (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
                    }
177
                    break;
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
                case COMMON_SAMPLER_TYPE_TOP_K:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_k    (params.top_k));
                    break;
                case COMMON_SAMPLER_TYPE_TOP_P:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_p    (params.top_p, params.min_keep));
                    break;
                case COMMON_SAMPLER_TYPE_MIN_P:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_min_p    (params.min_p, params.min_keep));
                    break;
                case COMMON_SAMPLER_TYPE_XTC:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_xtc      (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
                    break;
                case COMMON_SAMPLER_TYPE_TYPICAL_P:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_typical  (params.typ_p, params.min_keep));
                    break;
                case COMMON_SAMPLER_TYPE_TEMPERATURE:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
                    break;
                case COMMON_SAMPLER_TYPE_INFILL:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_infill   (model));
                    break;
199
200
201
                case COMMON_SAMPLER_TYPE_PENALTIES:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
                    break;
202
203
                default:
                    GGML_ASSERT(false && "unknown sampler type");
204
205
            }
        }
206
207
208
209
210
211
212
        llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
    } else if (params.mirostat == 1) {
        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
    } else if (params.mirostat == 2) {
        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
213
    } else {
214
        GGML_ASSERT(false && "unknown mirostat version");
215
216
217
218
219
    }

    return result;
}

220
void common_sampler_free(struct common_sampler * gsmpl) {
221
222
223
224
225
226
    if (gsmpl) {
        llama_sampler_free(gsmpl->grmr);

        llama_sampler_free(gsmpl->chain);

        delete gsmpl;
227
228
229
    }
}

230
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
231
232
233
    if (accept_grammar) {
        llama_sampler_accept(gsmpl->grmr, token);
    }
234

235
    llama_sampler_accept(gsmpl->chain, token);
236

237
    gsmpl->prev.push_back(token);
238
239
}

240
void common_sampler_reset(struct common_sampler * gsmpl) {
241
    llama_sampler_reset(gsmpl->grmr);
242

243
    llama_sampler_reset(gsmpl->chain);
244
245
}

246
247
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
    return new common_sampler {
248
249
250
251
252
253
254
        /* .params = */ gsmpl->params,
        /* .grmr   = */ llama_sampler_clone(gsmpl->grmr),
        /* .chain  = */ llama_sampler_clone(gsmpl->chain),
        /* .prev   = */ gsmpl->prev,
        /* .cur    = */ gsmpl->cur,
        /* .cur_p  = */ gsmpl->cur_p,
    };
255
256
}

257
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
258
    // TODO: measure grammar performance
259

260
261
262
263
264
265
266
    if (gsmpl) {
        llama_perf_sampler_print(gsmpl->chain);
    }
    if (ctx) {
        llama_perf_context_print(ctx);
    }
}
267

268
// sample a token from the idx-th output of `ctx`
// if grammar_first is false, the grammar is only used to validate (and possibly
// re-sample) the token picked by the chain - this is cheaper than constraining
// the full candidate list up front
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
    gsmpl->set_logits(ctx, idx);

    auto & grmr  = gsmpl->grmr;
    auto & chain = gsmpl->chain;
    auto & cur_p = gsmpl->cur_p; // initialized by set_logits

    if (grammar_first) {
        llama_sampler_apply(grmr, &cur_p);
    }

    llama_sampler_apply(chain, &cur_p);

    GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");

    const llama_token id = cur_p.data[cur_p.selected].id;

    if (grammar_first) {
        return id;
    }

    // check if the sampled token fits the grammar
    {
        llama_token_data       single_token_data       = { id, 1.0f, 0.0f };
        llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };

        llama_sampler_apply(grmr, &single_token_data_array);

        // the grammar sampler marks rejected tokens by setting their logit to -inf
        const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
        if (is_valid) {
            return id;
        }
    }

    // resampling:
    // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
    gsmpl->set_logits(ctx, idx);

    llama_sampler_apply(grmr,  &cur_p);
    llama_sampler_apply(chain, &cur_p);

    GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");

    return cur_p.data[cur_p.selected].id;
}
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
    GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");

    std::vector<llama_token> result;
    result.reserve(idxs.size());

    size_t i = 0;
    for (; i < draft.size(); i++) {
        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);

        common_sampler_accept(gsmpl, id, true);

        result.push_back(id);

        if (draft[i] != id) {
            break;
        }
    }

    if (i == draft.size()) {
        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);

        common_sampler_accept(gsmpl, id, true);

        result.push_back(id);
    }

    return result;
}

// convenience overload: verify `draft` against consecutive output indices 0..draft.size()
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
    std::vector<int> idxs;
    idxs.reserve(draft.size() + 1);
    for (int i = 0; i < (int) draft.size() + 1; ++i) {
        idxs.push_back(i);
    }

    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
}
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
354
355
356
357
    return llama_sampler_get_seed(gsmpl->chain);
}

// helpers
358

359
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl) {
360
361
    return &gsmpl->cur_p;
}
362

363
// most recently accepted token (rat(0) = newest entry in the history)
llama_token common_sampler_last(const struct common_sampler * gsmpl) {
    return gsmpl->prev.rat(0);
}
std::string common_sampler_print(const struct common_sampler * gsmpl) {
368
    std::string result = "logits ";
369

370
371
372
    for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
        const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
        result += std::string("-> ") + llama_sampler_name(smpl) + " ";
373
374
    }

375
376
377
    return result;
}

378
// detokenize the last `n` accepted tokens (oldest first); n is clamped to the history size
std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_main, int n) {
    n = std::min(n, (int) gsmpl->prev.size());

    if (n <= 0) {
        return "";
    }

    std::string result;
    result.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab

    // rat(i) is newest-first, so iterate i from n-1 down to 0 to emit oldest-first
    for (int i = n - 1; i >= 0; i--) {
        const llama_token id = gsmpl->prev.rat(i);

        GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen");

        result += common_token_to_piece(ctx_main, id);
    }

    return result;
}
char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
400
    switch (cnstr) {
401
402
403
404
405
406
407
408
        case COMMON_SAMPLER_TYPE_DRY:         return 'd';
        case COMMON_SAMPLER_TYPE_TOP_K:       return 'k';
        case COMMON_SAMPLER_TYPE_TYPICAL_P:   return 'y';
        case COMMON_SAMPLER_TYPE_TOP_P:       return 'p';
        case COMMON_SAMPLER_TYPE_MIN_P:       return 'm';
        case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
        case COMMON_SAMPLER_TYPE_XTC:         return 'x';
        case COMMON_SAMPLER_TYPE_INFILL:      return 'i';
409
        case COMMON_SAMPLER_TYPE_PENALTIES:   return 'e';
410
411
412
        default : return '?';
    }
}
413

414
std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
415
    switch (cnstr) {
416
417
418
419
420
421
422
423
        case COMMON_SAMPLER_TYPE_DRY:         return "dry";
        case COMMON_SAMPLER_TYPE_TOP_K:       return "top_k";
        case COMMON_SAMPLER_TYPE_TYPICAL_P:   return "typ_p";
        case COMMON_SAMPLER_TYPE_TOP_P:       return "top_p";
        case COMMON_SAMPLER_TYPE_MIN_P:       return "min_p";
        case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
        case COMMON_SAMPLER_TYPE_XTC:         return "xtc";
        case COMMON_SAMPLER_TYPE_INFILL:      return "infill";
424
        case COMMON_SAMPLER_TYPE_PENALTIES:   return "penalties";
425
        default : return "";
426
    }
427
}
428

429
430
431
432
433
434
435
436
437
438
std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
    std::unordered_map<std::string, common_sampler_type> sampler_canonical_name_map {
        { "dry",         COMMON_SAMPLER_TYPE_DRY },
        { "top_k",       COMMON_SAMPLER_TYPE_TOP_K },
        { "top_p",       COMMON_SAMPLER_TYPE_TOP_P },
        { "typ_p",       COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "min_p",       COMMON_SAMPLER_TYPE_MIN_P },
        { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
        { "xtc",         COMMON_SAMPLER_TYPE_XTC },
        { "infill",      COMMON_SAMPLER_TYPE_INFILL },
439
        { "penalties",   COMMON_SAMPLER_TYPE_PENALTIES },
440
    };
441

442
443
    // since samplers names are written multiple ways
    // make it ready for both system names and input names
444
445
446
447
448
449
450
451
452
453
    std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
        { "top-k",       COMMON_SAMPLER_TYPE_TOP_K },
        { "top-p",       COMMON_SAMPLER_TYPE_TOP_P },
        { "nucleus",     COMMON_SAMPLER_TYPE_TOP_P },
        { "typical-p",   COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "typical",     COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "typ-p",       COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "typ",         COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "min-p",       COMMON_SAMPLER_TYPE_MIN_P },
        { "temp",        COMMON_SAMPLER_TYPE_TEMPERATURE },
454
    };
455

456
    std::vector<common_sampler_type> samplers;
457
    samplers.reserve(names.size());
458

459
460
461
462
463
464
465
466
467
    for (const auto & name : names) {
        auto sampler = sampler_canonical_name_map.find(name);
        if (sampler != sampler_canonical_name_map.end()) {
            samplers.push_back(sampler->second);
        } else {
            if (allow_alt_names) {
                sampler = sampler_alt_name_map.find(name);
                if (sampler != sampler_alt_name_map.end()) {
                    samplers.push_back(sampler->second);
468
469
470
471
472
                }
            }
        }
    }

473
    return samplers;
474
475
}

476
477
478
479
480
481
482
483
484
485
std::vector<common_sampler_type> common_sampler_types_from_chars(const std::string & chars) {
    std::unordered_map<char, common_sampler_type> sampler_name_map = {
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY),         COMMON_SAMPLER_TYPE_DRY },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K),       COMMON_SAMPLER_TYPE_TOP_K },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P),   COMMON_SAMPLER_TYPE_TYPICAL_P },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P),       COMMON_SAMPLER_TYPE_TOP_P },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P),       COMMON_SAMPLER_TYPE_MIN_P },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL),      COMMON_SAMPLER_TYPE_INFILL },
486
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES),   COMMON_SAMPLER_TYPE_PENALTIES },
487
    };
488

489
    std::vector<common_sampler_type> samplers;
490
    samplers.reserve(chars.size());
491

492
493
494
495
496
    for (const auto & c : chars) {
        const auto sampler = sampler_name_map.find(c);
        if (sampler != sampler_name_map.end()) {
            samplers.push_back(sampler->second);
        }
497
    }
498
499

    return samplers;
500
}