/**
 * llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
 *
 * MIT License
 *
 * Copyright (c) 2023-2024 The ggml authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "sampling.h"

29
#include "common.h"
30

31
32
#include <cmath>
#include <unordered_map>
33

34
35
36
37
38
// the ring buffer works similarly to std::deque, but with a fixed capacity
// TODO: deduplicate with llama-impl.h
//
// layout: `first` is the index of the oldest element, `pos` is the index of the slot that the
// next push_back() writes to, and `sz` is the number of live elements (sz <= capacity)
template<typename T>
struct ring_buffer {
    ring_buffer(size_t cap) : capacity(cap), data(cap) {}

    // oldest element; throws std::runtime_error if the buffer is empty
    T & front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    const T & front() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    // most recently pushed element; throws std::runtime_error if the buffer is empty
    T & back() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        // `pos` is the *next* write slot, so the newest element sits sz - 1 slots after `first`
        return data[(first + sz - 1) % capacity];
    }

    const T & back() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[(first + sz - 1) % capacity];
    }

    // append `value`; when the buffer is full the oldest element is overwritten
    // throws std::runtime_error if the buffer was created with capacity 0
    void push_back(const T & value) {
        if (capacity == 0) {
            // without this guard the index arithmetic below would compute `% 0`
            throw std::runtime_error("ring buffer: capacity is zero");
        }

        if (sz == capacity) {
            // advance the start when buffer is full
            first = (first + 1) % capacity;
        } else {
            sz++;
        }
        data[pos] = value;
        pos = (pos + 1) % capacity;
    }

    // remove and return the oldest element; throws std::runtime_error if the buffer is empty
    T pop_front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        T value = data[first];
        first = (first + 1) % capacity;
        sz--;
        return value;
    }

    // reverse access: rat(0) is the newest element, rat(sz - 1) the oldest
    // throws std::runtime_error if i >= size()
    const T & rat(size_t i) const {
        if (i >= sz) {
            throw std::runtime_error("ring buffer: index out of bounds");
        }
        return data[(first + sz - i - 1) % capacity];
    }

    // copy of the live elements, ordered from oldest to newest
    std::vector<T> to_vector() const {
        std::vector<T> result;
        result.reserve(sz);
        for (size_t i = 0; i < sz; i++) {
            result.push_back(data[(first + i) % capacity]);
        }
        return result;
    }

    void clear() {
        // here only reset the status of the buffer
        sz = 0;
        first = 0;
        pos = 0;
    }

    bool empty() const {
        return sz == 0;
    }

    size_t size() const {
        return sz;
    }

    size_t capacity = 0;
    size_t sz = 0;
    size_t first = 0;
    size_t pos = 0;
    std::vector<T> data;
};
126

127
128
struct common_sampler {
    common_params_sampling params;
129

130
131
    struct llama_sampler * grmr;
    struct llama_sampler * chain;
132

133
    ring_buffer<llama_token> prev;
134

135
    std::vector<llama_token_data> cur;
136

137
    llama_token_data_array cur_p;
138

139
140
    void set_logits(struct llama_context * ctx, int idx) {
        const auto * logits = llama_get_logits_ith(ctx, idx);
141

142
143
144
145
146
147
148
149
150
151
152
153
        const int n_vocab = llama_n_vocab(llama_get_model(ctx));

        cur.resize(n_vocab);

        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
            cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
        }

        cur_p = { cur.data(), cur.size(), -1, false };
    }
};

154
std::string common_params_sampling::print() const {
155
156
157
158
    char result[1024];

    snprintf(result, sizeof(result),
            "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
159
160
            "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
            "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
161
            "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
162
            penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
163
164
            dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
            top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
165
            mirostat, mirostat_eta, mirostat_tau);
166
167
168
169

    return std::string(result);
}

170
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
171
172
173
174
    llama_sampler_chain_params lparams = llama_sampler_chain_default_params();

    lparams.no_perf = params.no_perf;

175
    auto * result = new common_sampler {
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
        /* .params = */ params,
        /* .grmr   = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
        /* .chain  = */ llama_sampler_chain_init(lparams),
        /* .prev   = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
        /* .cur    = */ {},
        /* .cur_p  = */ {},
    };

    llama_sampler_chain_add(result->chain,
            llama_sampler_init_logit_bias(
                llama_n_vocab(model),
                params.logit_bias.size(),
                params.logit_bias.data()));

    llama_sampler_chain_add(result->chain,
            llama_sampler_init_penalties(
                llama_n_vocab  (model),
                llama_token_eos(model),
                llama_token_nl (model),
                params.penalty_last_n,
                params.penalty_repeat,
                params.penalty_freq,
                params.penalty_present,
                params.penalize_nl,
                params.ignore_eos));

202
203
204
205
206
207
208
209
210
211
212
213
214
    if (params.mirostat == 0) {
        for (const auto & cnstr : params.samplers) {
            switch (cnstr) {
                    case COMMON_SAMPLER_TYPE_DRY:
                    {
                        std::vector<const char*> c_breakers;
                        c_breakers.reserve(params.dry_sequence_breakers.size());
                        for (const auto& str : params.dry_sequence_breakers) {
                            c_breakers.push_back(str.c_str());
                        }

                        llama_sampler_chain_add(result->chain, llama_sampler_init_dry      (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
                    }
215
                        break;
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
                case COMMON_SAMPLER_TYPE_TOP_K:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_k    (params.top_k));
                    break;
                case COMMON_SAMPLER_TYPE_TOP_P:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_p    (params.top_p, params.min_keep));
                    break;
                case COMMON_SAMPLER_TYPE_MIN_P:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_min_p    (params.min_p, params.min_keep));
                    break;
                case COMMON_SAMPLER_TYPE_XTC:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_xtc      (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
                    break;
                case COMMON_SAMPLER_TYPE_TYPICAL_P:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_typical  (params.typ_p, params.min_keep));
                    break;
                case COMMON_SAMPLER_TYPE_TEMPERATURE:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
                    break;
                case COMMON_SAMPLER_TYPE_INFILL:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_infill   (model));
                    break;
                default:
                    GGML_ASSERT(false && "unknown sampler type");
239
240
            }
        }
241
242
243
244
245
246
247
        llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
    } else if (params.mirostat == 1) {
        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
    } else if (params.mirostat == 2) {
        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
248
    } else {
249
        GGML_ASSERT(false && "unknown mirostat version");
250
251
252
253
254
    }

    return result;
}

255
void common_sampler_free(struct common_sampler * gsmpl) {
256
257
258
259
260
261
    if (gsmpl) {
        llama_sampler_free(gsmpl->grmr);

        llama_sampler_free(gsmpl->chain);

        delete gsmpl;
262
263
264
    }
}

265
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
266
267
268
    if (accept_grammar) {
        llama_sampler_accept(gsmpl->grmr, token);
    }
269

270
    llama_sampler_accept(gsmpl->chain, token);
271

272
    gsmpl->prev.push_back(token);
273
274
}

275
void common_sampler_reset(struct common_sampler * gsmpl) {
276
    llama_sampler_reset(gsmpl->grmr);
277

278
    llama_sampler_reset(gsmpl->chain);
279
280
}

281
282
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
    return new common_sampler {
283
284
285
286
287
288
289
        /* .params = */ gsmpl->params,
        /* .grmr   = */ llama_sampler_clone(gsmpl->grmr),
        /* .chain  = */ llama_sampler_clone(gsmpl->chain),
        /* .prev   = */ gsmpl->prev,
        /* .cur    = */ gsmpl->cur,
        /* .cur_p  = */ gsmpl->cur_p,
    };
290
291
}

292
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
293
    // TODO: measure grammar performance
294

295
296
297
298
299
300
301
    if (gsmpl) {
        llama_perf_sampler_print(gsmpl->chain);
    }
    if (ctx) {
        llama_perf_context_print(ctx);
    }
}
302

303
// sample a token from the logits of output `idx` of `ctx`
//
// - grammar_first == true : apply the grammar to all candidates, then the chain, and return
// - grammar_first == false: apply the chain only, then check the sampled token against the
//   grammar; if the grammar rejects it, start over with grammar + chain
//
// the sampled token is not accepted here - the caller does that via common_sampler_accept()
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
    gsmpl->set_logits(ctx, idx);

    struct llama_sampler * const grammar = gsmpl->grmr;
    struct llama_sampler * const samplers = gsmpl->chain;
    auto & cur_p = gsmpl->cur_p; // initialized by set_logits

    if (grammar_first) {
        llama_sampler_apply(grammar, &cur_p);
    }

    llama_sampler_apply(samplers, &cur_p);

    GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");

    const llama_token id = cur_p.data[cur_p.selected].id;

    if (grammar_first) {
        return id;
    }

    // check if the sampled token fits the grammar:
    // run the grammar over a one-element candidate list and see whether the token was masked out
    llama_token_data       probe       = { id, 1.0f, 0.0f };
    llama_token_data_array probe_array = { &probe, 1, -1, false };

    llama_sampler_apply(grammar, &probe_array);

    const bool rejected = probe_array.data[0].logit == -INFINITY;
    if (!rejected) {
        return id;
    }

    // resampling:
    // the token is not valid, so sample again, but this time apply the grammar sampler first
    // and the sampling chain second
    gsmpl->set_logits(ctx, idx);

    llama_sampler_apply(grammar,  &cur_p);
    llama_sampler_apply(samplers, &cur_p);

    GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");

    return cur_p.data[cur_p.selected].id;
}
348

349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
// sample and accept tokens at the outputs listed in `idxs`, comparing each one with `draft`
//
// sampling stops right after the first token that differs from the draft; if every draft token
// matches, one extra token is sampled at the last index. The returned vector therefore holds
// between 1 and draft.size() + 1 tokens, all of which have been accepted (grammar included).
//
// requires idxs.size() == draft.size() + 1
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
    GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");

    const size_t n_draft = draft.size();

    std::vector<llama_token> result;
    result.reserve(idxs.size());

    // positions 0 .. n_draft-1 are compared against the draft, position n_draft is the bonus token
    for (size_t i = 0; i <= n_draft; ++i) {
        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);

        common_sampler_accept(gsmpl, id, true);

        result.push_back(id);

        const bool is_bonus = i == n_draft;
        if (is_bonus || draft[i] != id) {
            break;
        }
    }

    return result;
}

// convenience overload: uses the output indices 0, 1, ..., draft.size()
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
    const size_t n_idxs = draft.size() + 1;

    std::vector<int> idxs;
    idxs.reserve(n_idxs);
    for (size_t i = 0; i < n_idxs; ++i) {
        idxs.push_back(static_cast<int>(i));
    }

    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
}

uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
389
390
391
392
    return llama_sampler_get_seed(gsmpl->chain);
}

// helpers
393

394
// pointer to the sampler's internal candidate array (not a copy)
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl) {
    llama_token_data_array & candidates = gsmpl->cur_p;
    return &candidates;
}
397

398
// most recent token in the history; the ring buffer throws std::runtime_error if it is empty
llama_token common_sampler_last(const struct common_sampler * gsmpl) {
    const auto & history = gsmpl->prev;
    return history.rat(0);
}
401

402
std::string common_sampler_print(const struct common_sampler * gsmpl) {
403
    std::string result = "logits ";
404

405
406
407
    for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
        const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
        result += std::string("-> ") + llama_sampler_name(smpl) + " ";
408
409
    }

410
411
412
    return result;
}

413
// text of the last `n` tokens in the history, oldest first
// `n` is clamped to the history size; an empty string is returned when nothing is left
std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_main, int n) {
    const int n_history = (int) gsmpl->prev.size();

    n = std::min(n, n_history);

    if (n <= 0) {
        return "";
    }

    std::string result;
    result.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab

    // rat(k) counts back from the newest token, so walk k from n-1 down to 0
    for (int remaining = n; remaining > 0; --remaining) {
        const llama_token id = gsmpl->prev.rat(remaining - 1);

        GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen");

        result += common_token_to_piece(ctx_main, id);
    }

    return result;
}

434
char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
435
    switch (cnstr) {
436
437
438
439
440
441
442
443
        case COMMON_SAMPLER_TYPE_DRY:         return 'd';
        case COMMON_SAMPLER_TYPE_TOP_K:       return 'k';
        case COMMON_SAMPLER_TYPE_TYPICAL_P:   return 'y';
        case COMMON_SAMPLER_TYPE_TOP_P:       return 'p';
        case COMMON_SAMPLER_TYPE_MIN_P:       return 'm';
        case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
        case COMMON_SAMPLER_TYPE_XTC:         return 'x';
        case COMMON_SAMPLER_TYPE_INFILL:      return 'i';
444
445
446
        default : return '?';
    }
}
447

448
std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
449
    switch (cnstr) {
450
451
452
453
454
455
456
457
        case COMMON_SAMPLER_TYPE_DRY:         return "dry";
        case COMMON_SAMPLER_TYPE_TOP_K:       return "top_k";
        case COMMON_SAMPLER_TYPE_TYPICAL_P:   return "typ_p";
        case COMMON_SAMPLER_TYPE_TOP_P:       return "top_p";
        case COMMON_SAMPLER_TYPE_MIN_P:       return "min_p";
        case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
        case COMMON_SAMPLER_TYPE_XTC:         return "xtc";
        case COMMON_SAMPLER_TYPE_INFILL:      return "infill";
458
        default : return "";
459
    }
460
}
461

462
463
464
465
466
467
468
469
470
471
std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
    std::unordered_map<std::string, common_sampler_type> sampler_canonical_name_map {
        { "dry",         COMMON_SAMPLER_TYPE_DRY },
        { "top_k",       COMMON_SAMPLER_TYPE_TOP_K },
        { "top_p",       COMMON_SAMPLER_TYPE_TOP_P },
        { "typ_p",       COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "min_p",       COMMON_SAMPLER_TYPE_MIN_P },
        { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
        { "xtc",         COMMON_SAMPLER_TYPE_XTC },
        { "infill",      COMMON_SAMPLER_TYPE_INFILL },
472
    };
473

474
475
    // since samplers names are written multiple ways
    // make it ready for both system names and input names
476
477
478
479
480
481
482
483
484
485
    std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
        { "top-k",       COMMON_SAMPLER_TYPE_TOP_K },
        { "top-p",       COMMON_SAMPLER_TYPE_TOP_P },
        { "nucleus",     COMMON_SAMPLER_TYPE_TOP_P },
        { "typical-p",   COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "typical",     COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "typ-p",       COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "typ",         COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "min-p",       COMMON_SAMPLER_TYPE_MIN_P },
        { "temp",        COMMON_SAMPLER_TYPE_TEMPERATURE },
486
    };
487

488
    std::vector<common_sampler_type> samplers;
489
    samplers.reserve(names.size());
490

491
492
493
494
495
496
497
498
499
    for (const auto & name : names) {
        auto sampler = sampler_canonical_name_map.find(name);
        if (sampler != sampler_canonical_name_map.end()) {
            samplers.push_back(sampler->second);
        } else {
            if (allow_alt_names) {
                sampler = sampler_alt_name_map.find(name);
                if (sampler != sampler_alt_name_map.end()) {
                    samplers.push_back(sampler->second);
500
501
502
503
504
                }
            }
        }
    }

505
    return samplers;
506
507
}

508
509
510
511
512
513
514
515
516
517
std::vector<common_sampler_type> common_sampler_types_from_chars(const std::string & chars) {
    std::unordered_map<char, common_sampler_type> sampler_name_map = {
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY),         COMMON_SAMPLER_TYPE_DRY },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K),       COMMON_SAMPLER_TYPE_TOP_K },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P),   COMMON_SAMPLER_TYPE_TYPICAL_P },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P),       COMMON_SAMPLER_TYPE_TOP_P },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P),       COMMON_SAMPLER_TYPE_MIN_P },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL),      COMMON_SAMPLER_TYPE_INFILL },
518
    };
519

520
    std::vector<common_sampler_type> samplers;
521
    samplers.reserve(chars.size());
522

523
524
525
526
527
    for (const auto & c : chars) {
        const auto sampler = sampler_name_map.find(c);
        if (sampler != sampler_name_map.end()) {
            samplers.push_back(sampler->second);
        }
528
    }
529
530

    return samplers;
531
}