sampling.cpp 19.2 KB
Newer Older
1
2
#include "sampling.h"

3
#include "common.h"
4

5
6
#include <cmath>
#include <unordered_map>
7

8
9
10
11
12
// the ring buffer works similarly to std::deque, but with a fixed capacity
// TODO: deduplicate with llama-impl.h
//
// Layout: `first` is the index of the oldest element, `pos` is the index of the slot that the
// next push_back() writes to, and `sz` is the number of stored elements. Once the buffer is
// full, push_back() overwrites the oldest element.
template<typename T>
struct ring_buffer {
    ring_buffer(size_t cap) : capacity(cap), data(cap) {}

    // oldest element; throws std::runtime_error when the buffer is empty
    T & front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    const T & front() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    // most recently pushed element; throws std::runtime_error when the buffer is empty
    T & back() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[back_index()];
    }

    const T & back() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[back_index()];
    }

    // append a value, dropping the oldest element when the buffer is already full
    // throws std::runtime_error when the buffer was created with zero capacity
    void push_back(const T & value) {
        if (capacity == 0) {
            // without this guard the modulo operations below would divide by zero
            throw std::runtime_error("ring buffer: capacity is zero");
        }

        if (sz == capacity) {
            // advance the start when buffer is full
            first = (first + 1) % capacity;
        } else {
            sz++;
        }
        data[pos] = value;
        pos = (pos + 1) % capacity;
    }

    // remove and return the oldest element; throws std::runtime_error when the buffer is empty
    T pop_front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        T value = data[first];
        first = (first + 1) % capacity;
        sz--;
        return value;
    }

    // reverse access: rat(0) is the newest element, rat(size() - 1) is the oldest
    // throws std::runtime_error when i is out of range
    const T & rat(size_t i) const {
        if (i >= sz) {
            throw std::runtime_error("ring buffer: index out of bounds");
        }
        return data[(first + sz - i - 1) % capacity];
    }

    // copy the stored elements into a vector, ordered from oldest to newest
    std::vector<T> to_vector() const {
        std::vector<T> result;
        result.reserve(sz);
        for (size_t i = 0; i < sz; i++) {
            result.push_back(data[(first + i) % capacity]);
        }
        return result;
    }

    void clear() {
        // here only reset the status of the buffer
        sz = 0;
        first = 0;
        pos = 0;
    }

    bool empty() const {
        return sz == 0;
    }

    size_t size() const {
        return sz;
    }

    // index of the newest element: the slot right before the next write position
    // only meaningful when sz > 0 (which also implies capacity > 0)
    size_t back_index() const {
        return (pos + capacity - 1) % capacity;
    }

    size_t capacity = 0;
    size_t sz = 0;
    size_t first = 0;
    size_t pos = 0;
    std::vector<T> data;
};
100

101
102
struct common_sampler {
    common_params_sampling params;
103

104
105
    struct llama_sampler * grmr;
    struct llama_sampler * chain;
106

107
    ring_buffer<llama_token> prev;
108

109
    std::vector<llama_token_data> cur;
110

111
    llama_token_data_array cur_p;
112

113
114
    void set_logits(struct llama_context * ctx, int idx) {
        const auto * logits = llama_get_logits_ith(ctx, idx);
115

116
117
118
119
        const llama_model * model = llama_get_model(ctx);
        const llama_vocab * vocab = llama_model_get_vocab(model);

        const int n_vocab = llama_vocab_n_tokens(vocab);
120
121
122
123
124
125
126
127
128
129
130

        cur.resize(n_vocab);

        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
            cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
        }

        cur_p = { cur.data(), cur.size(), -1, false };
    }
};

131
std::string common_params_sampling::print() const {
132
133
134
135
    char result[1024];

    snprintf(result, sizeof(result),
            "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
136
            "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
137
            "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n"
138
            "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
139
            penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
140
            dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
141
            top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
142
            mirostat, mirostat_eta, mirostat_tau);
143
144
145
146

    return std::string(result);
}

147
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
148
149
    const llama_vocab * vocab = llama_model_get_vocab(model);

150
151
152
153
    llama_sampler_chain_params lparams = llama_sampler_chain_default_params();

    lparams.no_perf = params.no_perf;

154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
    struct llama_sampler * grmr;
    if (params.grammar.compare(0, 11, "%llguidance") == 0) {
#ifdef LLAMA_USE_LLGUIDANCE
        grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
#else
        GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE
    } else {
        std::vector<const char *> trigger_words;
        trigger_words.reserve(params.grammar_trigger_words.size());
        for (const auto & str : params.grammar_trigger_words) {
            trigger_words.push_back(str.word.c_str());
        }

        grmr = params.grammar_lazy
             ? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root",
                                               trigger_words.data(), trigger_words.size(),
                                               params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size())
             :      llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
    }

175
    auto * result = new common_sampler {
176
        /* .params = */ params,
177
        /* .grmr   = */ grmr,
178
179
180
181
182
183
184
185
        /* .chain  = */ llama_sampler_chain_init(lparams),
        /* .prev   = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
        /* .cur    = */ {},
        /* .cur_p  = */ {},
    };

    llama_sampler_chain_add(result->chain,
            llama_sampler_init_logit_bias(
186
                llama_vocab_n_tokens(vocab),
187
188
189
                params.logit_bias.size(),
                params.logit_bias.data()));

190
    if (params.mirostat == 0) {
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
        if (params.top_n_sigma >= 0) {
            llama_sampler_chain_add(result->chain, llama_sampler_init_top_k        (params.top_k));
            llama_sampler_chain_add(result->chain, llama_sampler_init_temp         (params.temp));
            llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma  (params.top_n_sigma));
        } else {
            for (const auto & cnstr : params.samplers) {
                switch (cnstr) {
                    case COMMON_SAMPLER_TYPE_DRY:
                        {
                            std::vector<const char *> c_breakers;
                            c_breakers.reserve(params.dry_sequence_breakers.size());
                            for (const auto & str : params.dry_sequence_breakers) {
                                c_breakers.push_back(str.c_str());
                            }

                            llama_sampler_chain_add(result->chain, llama_sampler_init_dry      (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
207
                        }
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
                        break;
                    case COMMON_SAMPLER_TYPE_TOP_K:
                        llama_sampler_chain_add(result->chain, llama_sampler_init_top_k    (params.top_k));
                        break;
                    case COMMON_SAMPLER_TYPE_TOP_P:
                        llama_sampler_chain_add(result->chain, llama_sampler_init_top_p    (params.top_p, params.min_keep));
                        break;
                    case COMMON_SAMPLER_TYPE_MIN_P:
                        llama_sampler_chain_add(result->chain, llama_sampler_init_min_p    (params.min_p, params.min_keep));
                        break;
                    case COMMON_SAMPLER_TYPE_XTC:
                        llama_sampler_chain_add(result->chain, llama_sampler_init_xtc      (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
                        break;
                    case COMMON_SAMPLER_TYPE_TYPICAL_P:
                        llama_sampler_chain_add(result->chain, llama_sampler_init_typical  (params.typ_p, params.min_keep));
                        break;
                    case COMMON_SAMPLER_TYPE_TEMPERATURE:
                        llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
                        break;
                    case COMMON_SAMPLER_TYPE_INFILL:
                        llama_sampler_chain_add(result->chain, llama_sampler_init_infill   (vocab));
                        break;
                    case COMMON_SAMPLER_TYPE_PENALTIES:
                        llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
                        break;
                    default:
                        GGML_ASSERT(false && "unknown sampler type");
                }
236
237
            }
        }
238
239
240
        llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
    } else if (params.mirostat == 1) {
        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
241
        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
242
243
244
    } else if (params.mirostat == 2) {
        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
245
    } else {
246
        GGML_ASSERT(false && "unknown mirostat version");
247
248
249
250
251
    }

    return result;
}

252
void common_sampler_free(struct common_sampler * gsmpl) {
253
254
255
256
257
258
    if (gsmpl) {
        llama_sampler_free(gsmpl->grmr);

        llama_sampler_free(gsmpl->chain);

        delete gsmpl;
259
260
261
    }
}

262
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
263
264
265
    if (accept_grammar) {
        llama_sampler_accept(gsmpl->grmr, token);
    }
266

267
    llama_sampler_accept(gsmpl->chain, token);
268

269
    gsmpl->prev.push_back(token);
270
271
}

272
void common_sampler_reset(struct common_sampler * gsmpl) {
273
    llama_sampler_reset(gsmpl->grmr);
274

275
    llama_sampler_reset(gsmpl->chain);
276
277
}

278
279
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
    return new common_sampler {
280
281
282
283
284
285
286
        /* .params = */ gsmpl->params,
        /* .grmr   = */ llama_sampler_clone(gsmpl->grmr),
        /* .chain  = */ llama_sampler_clone(gsmpl->chain),
        /* .prev   = */ gsmpl->prev,
        /* .cur    = */ gsmpl->cur,
        /* .cur_p  = */ gsmpl->cur_p,
    };
287
288
}

289
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
290
    // TODO: measure grammar performance
291

292
293
294
295
296
297
298
    if (gsmpl) {
        llama_perf_sampler_print(gsmpl->chain);
    }
    if (ctx) {
        llama_perf_context_print(ctx);
    }
}
299

300
// Sample a token from the logits of output `idx` of `ctx`.
//
// With `grammar_first` the grammar is applied to all candidates before the chain runs.
// Otherwise the chain samples first and the grammar only checks the sampled token;
// if that token is rejected, the candidates are rebuilt and sampling is repeated with
// the grammar applied up front.
//
// Aborts (GGML_ASSERT) when the chain does not select a token.
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
    gsmpl->set_logits(ctx, idx);

    auto * const grammar = gsmpl->grmr;
    auto * const smpl    = gsmpl->chain;
    auto &       cur_p   = gsmpl->cur_p; // initialized by set_logits

    if (grammar_first) {
        llama_sampler_apply(grammar, &cur_p);
    }

    llama_sampler_apply(smpl, &cur_p);

    GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");

    const llama_token id = cur_p.data[cur_p.selected].id;

    // the grammar was already enforced on the whole candidate list
    if (grammar_first) {
        return id;
    }

    // check if the sampled token fits the grammar:
    // run the grammar on a single-candidate array and see whether the token survives
    llama_token_data       probe       = { id, 1.0f, 0.0f };
    llama_token_data_array probe_array = { &probe, 1, -1, false };

    llama_sampler_apply(grammar, &probe_array);

    if (probe_array.data[0].logit != -INFINITY) {
        return id;
    }

    // resampling:
    // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
    gsmpl->set_logits(ctx, idx);

    llama_sampler_apply(grammar, &cur_p);
    llama_sampler_apply(smpl,    &cur_p);

    GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");

    return cur_p.data[cur_p.selected].id;
}
345

346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
// Sample and accept tokens at the positions in `idxs`, comparing each against the draft.
//
// Tokens are sampled one position at a time; every sampled token is accepted (including the
// grammar) and appended to the result. Sampling stops after the first token that differs from
// the draft. If the whole draft matches, one extra token is sampled at the last index.
//
// `idxs` must hold exactly draft.size() + 1 entries (enforced by GGML_ASSERT).
// Returns the sampled tokens: between 1 and draft.size() + 1 of them.
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
    GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");

    std::vector<llama_token> result;
    result.reserve(idxs.size());

    for (size_t i = 0; i < idxs.size(); i++) {
        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);

        common_sampler_accept(gsmpl, id, true);

        result.push_back(id);

        // stop once the extra token past the draft was sampled, or at the first mismatch
        if (i == draft.size() || draft[i] != id) {
            break;
        }
    }

    return result;
}

// Convenience overload: uses the output indices 0, 1, ..., draft.size() and forwards to the
// overload that takes an explicit index list.
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
    const size_t n_idxs = draft.size() + 1;

    std::vector<int> idxs;
    idxs.reserve(n_idxs);
    for (size_t i = 0; i < n_idxs; ++i) {
        idxs.push_back(static_cast<int>(i));
    }

    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
}

uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
386
387
388
389
    return llama_sampler_get_seed(gsmpl->chain);
}

// helpers
390

391
// Return a pointer to the sampler's internal candidate array (`cur_p`).
// The pointer refers to a member of `gsmpl`; no copy is made.
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl) {
    llama_token_data_array & candidates = gsmpl->cur_p;
    return &candidates;
}
394

395
// Return the most recently accepted token.
// The history lookup throws std::runtime_error when no token has been accepted yet.
llama_token common_sampler_last(const struct common_sampler * gsmpl) {
    const auto & history = gsmpl->prev;
    return history.rat(0);
}
398

399
std::string common_sampler_print(const struct common_sampler * gsmpl) {
400
    std::string result = "logits ";
401

402
403
404
    for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
        const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
        result += std::string("-> ") + llama_sampler_name(smpl) + " ";
405
406
    }

407
408
409
    return result;
}

410
// Convert the last `n` accepted tokens to text, oldest first.
// `n` is clamped to the number of tokens in the history; an empty string is returned when
// there is nothing to print. Aborts (GGML_ASSERT) on a null token in the history.
std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_main, int n) {
    const int n_prev = static_cast<int>(gsmpl->prev.size());
    if (n > n_prev) {
        n = n_prev;
    }

    if (n <= 0) {
        return "";
    }

    std::string result;
    result.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab

    // rat(0) is the newest token, so walk the reverse index downwards to emit oldest first
    for (int i = n; i-- > 0; ) {
        const llama_token id = gsmpl->prev.rat(i);

        GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen");

        result += common_token_to_piece(ctx_main, id);
    }

    return result;
}

431
char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
432
    switch (cnstr) {
433
434
435
436
437
438
439
440
        case COMMON_SAMPLER_TYPE_DRY:         return 'd';
        case COMMON_SAMPLER_TYPE_TOP_K:       return 'k';
        case COMMON_SAMPLER_TYPE_TYPICAL_P:   return 'y';
        case COMMON_SAMPLER_TYPE_TOP_P:       return 'p';
        case COMMON_SAMPLER_TYPE_MIN_P:       return 'm';
        case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
        case COMMON_SAMPLER_TYPE_XTC:         return 'x';
        case COMMON_SAMPLER_TYPE_INFILL:      return 'i';
441
        case COMMON_SAMPLER_TYPE_PENALTIES:   return 'e';
442
443
444
        default : return '?';
    }
}
445

446
std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
447
    switch (cnstr) {
448
449
450
451
452
453
454
455
        case COMMON_SAMPLER_TYPE_DRY:         return "dry";
        case COMMON_SAMPLER_TYPE_TOP_K:       return "top_k";
        case COMMON_SAMPLER_TYPE_TYPICAL_P:   return "typ_p";
        case COMMON_SAMPLER_TYPE_TOP_P:       return "top_p";
        case COMMON_SAMPLER_TYPE_MIN_P:       return "min_p";
        case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
        case COMMON_SAMPLER_TYPE_XTC:         return "xtc";
        case COMMON_SAMPLER_TYPE_INFILL:      return "infill";
456
        case COMMON_SAMPLER_TYPE_PENALTIES:   return "penalties";
457
        default : return "";
458
    }
459
}
460

461
462
463
464
465
466
467
468
469
470
std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
    std::unordered_map<std::string, common_sampler_type> sampler_canonical_name_map {
        { "dry",         COMMON_SAMPLER_TYPE_DRY },
        { "top_k",       COMMON_SAMPLER_TYPE_TOP_K },
        { "top_p",       COMMON_SAMPLER_TYPE_TOP_P },
        { "typ_p",       COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "min_p",       COMMON_SAMPLER_TYPE_MIN_P },
        { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
        { "xtc",         COMMON_SAMPLER_TYPE_XTC },
        { "infill",      COMMON_SAMPLER_TYPE_INFILL },
471
        { "penalties",   COMMON_SAMPLER_TYPE_PENALTIES },
472
    };
473

474
475
    // since samplers names are written multiple ways
    // make it ready for both system names and input names
476
477
478
479
480
481
482
483
484
485
    std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
        { "top-k",       COMMON_SAMPLER_TYPE_TOP_K },
        { "top-p",       COMMON_SAMPLER_TYPE_TOP_P },
        { "nucleus",     COMMON_SAMPLER_TYPE_TOP_P },
        { "typical-p",   COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "typical",     COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "typ-p",       COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "typ",         COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "min-p",       COMMON_SAMPLER_TYPE_MIN_P },
        { "temp",        COMMON_SAMPLER_TYPE_TEMPERATURE },
486
    };
487

488
    std::vector<common_sampler_type> samplers;
489
    samplers.reserve(names.size());
490

491
492
493
494
495
496
497
498
499
    for (const auto & name : names) {
        auto sampler = sampler_canonical_name_map.find(name);
        if (sampler != sampler_canonical_name_map.end()) {
            samplers.push_back(sampler->second);
        } else {
            if (allow_alt_names) {
                sampler = sampler_alt_name_map.find(name);
                if (sampler != sampler_alt_name_map.end()) {
                    samplers.push_back(sampler->second);
500
501
502
503
504
                }
            }
        }
    }

505
    return samplers;
506
507
}

508
509
510
511
512
513
514
515
516
517
std::vector<common_sampler_type> common_sampler_types_from_chars(const std::string & chars) {
    std::unordered_map<char, common_sampler_type> sampler_name_map = {
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY),         COMMON_SAMPLER_TYPE_DRY },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K),       COMMON_SAMPLER_TYPE_TOP_K },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P),   COMMON_SAMPLER_TYPE_TYPICAL_P },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P),       COMMON_SAMPLER_TYPE_TOP_P },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P),       COMMON_SAMPLER_TYPE_MIN_P },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL),      COMMON_SAMPLER_TYPE_INFILL },
518
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES),   COMMON_SAMPLER_TYPE_PENALTIES },
519
    };
520

521
    std::vector<common_sampler_type> samplers;
522
    samplers.reserve(chars.size());
523

524
525
526
527
528
    for (const auto & c : chars) {
        const auto sampler = sampler_name_map.find(c);
        if (sampler != sampler_name_map.end()) {
            samplers.push_back(sampler->second);
        }
529
    }
530
531

    return samplers;
532
}