/**
 * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
 *
 * MIT License
 *
 * Copyright (c) 2023-2024 The ggml authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "sampling.h"

29
#include "common.h"
30

31
32
#include <cmath>
#include <unordered_map>
33

34
35
36
37
38
// the ring buffer works similarly to std::deque, but with a fixed capacity:
// once `capacity` elements are stored, every push_back() overwrites the oldest one
// TODO: deduplicate with llama-impl.h
template<typename T>
struct ring_buffer {
    ring_buffer(size_t cap) : capacity(cap), data(cap) {}

    // oldest stored element; throws std::runtime_error when the buffer is empty
    T & front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    const T & front() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    // newest stored element; throws std::runtime_error when the buffer is empty
    T & back() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        // `pos` is the NEXT write slot, so the newest element sits one slot before it;
        // use the same index arithmetic as rat(0) (sz > 0 implies capacity > 0)
        return data[(first + sz - 1) % capacity];
    }

    const T & back() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[(first + sz - 1) % capacity];
    }

    // append a value, dropping the oldest element when the buffer is already full
    // throws std::runtime_error when the buffer was created with zero capacity
    void push_back(const T & value) {
        if (capacity == 0) {
            // the modulo arithmetic below would divide by zero
            throw std::runtime_error("ring buffer: capacity is zero");
        }

        if (sz == capacity) {
            // advance the start when buffer is full
            first = (first + 1) % capacity;
        } else {
            sz++;
        }
        data[pos] = value;
        pos = (pos + 1) % capacity;
    }

    // remove and return the oldest element; throws std::runtime_error when empty
    T pop_front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        T value = data[first];
        first = (first + 1) % capacity;
        sz--;
        return value;
    }

    // reverse access: rat(0) is the newest element, rat(size() - 1) the oldest
    // throws std::runtime_error when i >= size()
    const T & rat(size_t i) const {
        if (i >= sz) {
            throw std::runtime_error("ring buffer: index out of bounds");
        }
        return data[(first + sz - i - 1) % capacity];
    }

    // copy of the stored elements, ordered from oldest to newest
    std::vector<T> to_vector() const {
        std::vector<T> result;
        result.reserve(sz);
        for (size_t i = 0; i < sz; i++) {
            result.push_back(data[(first + i) % capacity]);
        }
        return result;
    }

    void clear() {
        // here only reset the status of the buffer
        sz = 0;
        first = 0;
        pos = 0;
    }

    bool empty() const {
        return sz == 0;
    }

    size_t size() const {
        return sz;
    }

    size_t capacity = 0; // maximum number of stored elements
    size_t sz = 0;       // number of elements currently stored
    size_t first = 0;    // index of the oldest element
    size_t pos = 0;      // index of the next write slot
    std::vector<T> data;
};
126

127
128
struct common_sampler {
    common_params_sampling params;
129

130
131
    struct llama_sampler * grmr;
    struct llama_sampler * chain;
132

133
    ring_buffer<llama_token> prev;
134

135
    std::vector<llama_token_data> cur;
136

137
    llama_token_data_array cur_p;
138

139
140
    void set_logits(struct llama_context * ctx, int idx) {
        const auto * logits = llama_get_logits_ith(ctx, idx);
141

142
143
144
145
146
147
148
149
150
151
152
153
        const int n_vocab = llama_n_vocab(llama_get_model(ctx));

        cur.resize(n_vocab);

        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
            cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
        }

        cur_p = { cur.data(), cur.size(), -1, false };
    }
};

154
std::string common_params_sampling::print() const {
155
156
157
158
    char result[1024];

    snprintf(result, sizeof(result),
            "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
159
160
            "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
            "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
161
            "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
162
            penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
163
164
            dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
            top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
165
            mirostat, mirostat_eta, mirostat_tau);
166
167
168
169

    return std::string(result);
}

170
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
171
172
173
174
    llama_sampler_chain_params lparams = llama_sampler_chain_default_params();

    lparams.no_perf = params.no_perf;

175
    auto * result = new common_sampler {
176
177
178
179
180
181
182
183
184
185
186
187
188
189
        /* .params = */ params,
        /* .grmr   = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
        /* .chain  = */ llama_sampler_chain_init(lparams),
        /* .prev   = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
        /* .cur    = */ {},
        /* .cur_p  = */ {},
    };

    llama_sampler_chain_add(result->chain,
            llama_sampler_init_logit_bias(
                llama_n_vocab(model),
                params.logit_bias.size(),
                params.logit_bias.data()));

190
191
192
    if (params.mirostat == 0) {
        for (const auto & cnstr : params.samplers) {
            switch (cnstr) {
193
                case COMMON_SAMPLER_TYPE_DRY:
194
                    {
195
                        std::vector<const char *> c_breakers;
196
                        c_breakers.reserve(params.dry_sequence_breakers.size());
197
                        for (const auto & str : params.dry_sequence_breakers) {
198
199
200
201
202
                            c_breakers.push_back(str.c_str());
                        }

                        llama_sampler_chain_add(result->chain, llama_sampler_init_dry      (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
                    }
203
                    break;
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
                case COMMON_SAMPLER_TYPE_TOP_K:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_k    (params.top_k));
                    break;
                case COMMON_SAMPLER_TYPE_TOP_P:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_p    (params.top_p, params.min_keep));
                    break;
                case COMMON_SAMPLER_TYPE_MIN_P:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_min_p    (params.min_p, params.min_keep));
                    break;
                case COMMON_SAMPLER_TYPE_XTC:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_xtc      (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
                    break;
                case COMMON_SAMPLER_TYPE_TYPICAL_P:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_typical  (params.typ_p, params.min_keep));
                    break;
                case COMMON_SAMPLER_TYPE_TEMPERATURE:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
                    break;
                case COMMON_SAMPLER_TYPE_INFILL:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_infill   (model));
                    break;
225
226
227
                case COMMON_SAMPLER_TYPE_PENALTIES:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
                    break;
228
229
                default:
                    GGML_ASSERT(false && "unknown sampler type");
230
231
            }
        }
232
233
234
235
236
237
238
        llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
    } else if (params.mirostat == 1) {
        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
    } else if (params.mirostat == 2) {
        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
239
    } else {
240
        GGML_ASSERT(false && "unknown mirostat version");
241
242
243
244
245
    }

    return result;
}

246
void common_sampler_free(struct common_sampler * gsmpl) {
247
248
249
250
251
252
    if (gsmpl) {
        llama_sampler_free(gsmpl->grmr);

        llama_sampler_free(gsmpl->chain);

        delete gsmpl;
253
254
255
    }
}

256
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
257
258
259
    if (accept_grammar) {
        llama_sampler_accept(gsmpl->grmr, token);
    }
260

261
    llama_sampler_accept(gsmpl->chain, token);
262

263
    gsmpl->prev.push_back(token);
264
265
}

266
void common_sampler_reset(struct common_sampler * gsmpl) {
267
    llama_sampler_reset(gsmpl->grmr);
268

269
    llama_sampler_reset(gsmpl->chain);
270
271
}

272
273
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
    return new common_sampler {
274
275
276
277
278
279
280
        /* .params = */ gsmpl->params,
        /* .grmr   = */ llama_sampler_clone(gsmpl->grmr),
        /* .chain  = */ llama_sampler_clone(gsmpl->chain),
        /* .prev   = */ gsmpl->prev,
        /* .cur    = */ gsmpl->cur,
        /* .cur_p  = */ gsmpl->cur_p,
    };
281
282
}

283
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
284
    // TODO: measure grammar performance
285

286
287
288
289
290
291
292
    if (gsmpl) {
        llama_perf_sampler_print(gsmpl->chain);
    }
    if (ctx) {
        llama_perf_context_print(ctx);
    }
}
293

294
// sample one token from the logits at output `idx` of `ctx`
//
// grammar_first == true : apply the grammar to all candidates, then the chain
// grammar_first == false: apply the chain only and check the result against the
//                         grammar afterwards; if it is rejected, re-sample with
//                         the grammar applied before the chain
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
    gsmpl->set_logits(ctx, idx);

    struct llama_sampler   * const grammar    = gsmpl->grmr;
    struct llama_sampler   * const chain      = gsmpl->chain;
    llama_token_data_array * const candidates = &gsmpl->cur_p; // initialized by set_logits

    if (grammar_first) {
        llama_sampler_apply(grammar, candidates);
        llama_sampler_apply(chain,   candidates);

        GGML_ASSERT(candidates->selected != -1 && "no selected token during sampling - check your sampling configuration");

        return candidates->data[candidates->selected].id;
    }

    llama_sampler_apply(chain, candidates);

    GGML_ASSERT(candidates->selected != -1 && "no selected token during sampling - check your sampling configuration");

    const llama_token id = candidates->data[candidates->selected].id;

    // check if the sampled token fits the grammar:
    // run the grammar on a one-element candidate array holding only that token
    {
        llama_token_data       probe       = { id, 1.0f, 0.0f };
        llama_token_data_array probe_array = { &probe, 1, -1, false };

        llama_sampler_apply(grammar, &probe_array);

        // a logit of -INFINITY after the grammar pass marks the token as rejected
        if (probe_array.data[0].logit != -INFINITY) {
            return id;
        }
    }

    // resampling:
    // the token is not valid, so sample again, but this time apply the grammar sampler
    // first and then the sampling chain
    gsmpl->set_logits(ctx, idx);

    llama_sampler_apply(grammar, candidates);
    llama_sampler_apply(chain,   candidates);

    GGML_ASSERT(candidates->selected != -1 && "no selected token during re-sampling - check your sampling configuration");

    return candidates->data[candidates->selected].id;
}
339

340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
    GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");

    std::vector<llama_token> result;
    result.reserve(idxs.size());

    size_t i = 0;
    for (; i < draft.size(); i++) {
        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);

        common_sampler_accept(gsmpl, id, true);

        result.push_back(id);

        if (draft[i] != id) {
            break;
        }
    }

    if (i == draft.size()) {
        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);

        common_sampler_accept(gsmpl, id, true);

        result.push_back(id);
    }

    return result;
}

// convenience overload: uses the output indices 0, 1, ..., draft.size()
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
    const size_t n_idxs = draft.size() + 1;

    std::vector<int> idxs;
    idxs.reserve(n_idxs);
    for (size_t i = 0; i < n_idxs; ++i) {
        idxs.push_back((int) i);
    }

    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
}

uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
380
381
382
383
    return llama_sampler_get_seed(gsmpl->chain);
}

// helpers
384

385
// pointer to the sampler's internal candidate array (`cur_p`); not a copy
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl) {
    llama_token_data_array & candidates = gsmpl->cur_p;
    return &candidates;
}
388

389
// most recently accepted token; rat(0) throws std::runtime_error when the history is empty
llama_token common_sampler_last(const struct common_sampler * gsmpl) {
    const auto & history = gsmpl->prev;
    return history.rat(0);
}
392

393
std::string common_sampler_print(const struct common_sampler * gsmpl) {
394
    std::string result = "logits ";
395

396
397
398
    for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
        const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
        result += std::string("-> ") + llama_sampler_name(smpl) + " ";
399
400
    }

401
402
403
    return result;
}

404
// text of the last `n` accepted tokens (fewer if the history is shorter), oldest first
// returns an empty string when there is nothing to print
std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_main, int n) {
    const auto & history = gsmpl->prev;

    n = std::min(n, (int) history.size());
    if (n <= 0) {
        return "";
    }

    std::string text;
    text.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab

    // rat(0) is the newest token, so count the age down to emit oldest -> newest
    int age = n;
    while (age-- > 0) {
        const llama_token id = history.rat(age);

        GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen");

        text += common_token_to_piece(ctx_main, id);
    }

    return text;
}

425
char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
426
    switch (cnstr) {
427
428
429
430
431
432
433
434
        case COMMON_SAMPLER_TYPE_DRY:         return 'd';
        case COMMON_SAMPLER_TYPE_TOP_K:       return 'k';
        case COMMON_SAMPLER_TYPE_TYPICAL_P:   return 'y';
        case COMMON_SAMPLER_TYPE_TOP_P:       return 'p';
        case COMMON_SAMPLER_TYPE_MIN_P:       return 'm';
        case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
        case COMMON_SAMPLER_TYPE_XTC:         return 'x';
        case COMMON_SAMPLER_TYPE_INFILL:      return 'i';
435
        case COMMON_SAMPLER_TYPE_PENALTIES:   return 'e';
436
437
438
        default : return '?';
    }
}
439

440
std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
441
    switch (cnstr) {
442
443
444
445
446
447
448
449
        case COMMON_SAMPLER_TYPE_DRY:         return "dry";
        case COMMON_SAMPLER_TYPE_TOP_K:       return "top_k";
        case COMMON_SAMPLER_TYPE_TYPICAL_P:   return "typ_p";
        case COMMON_SAMPLER_TYPE_TOP_P:       return "top_p";
        case COMMON_SAMPLER_TYPE_MIN_P:       return "min_p";
        case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
        case COMMON_SAMPLER_TYPE_XTC:         return "xtc";
        case COMMON_SAMPLER_TYPE_INFILL:      return "infill";
450
        case COMMON_SAMPLER_TYPE_PENALTIES:   return "penalties";
451
        default : return "";
452
    }
453
}
454

455
456
457
458
459
460
461
462
463
464
std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
    std::unordered_map<std::string, common_sampler_type> sampler_canonical_name_map {
        { "dry",         COMMON_SAMPLER_TYPE_DRY },
        { "top_k",       COMMON_SAMPLER_TYPE_TOP_K },
        { "top_p",       COMMON_SAMPLER_TYPE_TOP_P },
        { "typ_p",       COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "min_p",       COMMON_SAMPLER_TYPE_MIN_P },
        { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
        { "xtc",         COMMON_SAMPLER_TYPE_XTC },
        { "infill",      COMMON_SAMPLER_TYPE_INFILL },
465
        { "penalties",   COMMON_SAMPLER_TYPE_PENALTIES },
466
    };
467

468
469
    // since samplers names are written multiple ways
    // make it ready for both system names and input names
470
471
472
473
474
475
476
477
478
479
    std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
        { "top-k",       COMMON_SAMPLER_TYPE_TOP_K },
        { "top-p",       COMMON_SAMPLER_TYPE_TOP_P },
        { "nucleus",     COMMON_SAMPLER_TYPE_TOP_P },
        { "typical-p",   COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "typical",     COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "typ-p",       COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "typ",         COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "min-p",       COMMON_SAMPLER_TYPE_MIN_P },
        { "temp",        COMMON_SAMPLER_TYPE_TEMPERATURE },
480
    };
481

482
    std::vector<common_sampler_type> samplers;
483
    samplers.reserve(names.size());
484

485
486
487
488
489
490
491
492
493
    for (const auto & name : names) {
        auto sampler = sampler_canonical_name_map.find(name);
        if (sampler != sampler_canonical_name_map.end()) {
            samplers.push_back(sampler->second);
        } else {
            if (allow_alt_names) {
                sampler = sampler_alt_name_map.find(name);
                if (sampler != sampler_alt_name_map.end()) {
                    samplers.push_back(sampler->second);
494
495
496
497
498
                }
            }
        }
    }

499
    return samplers;
500
501
}

502
503
504
505
506
507
508
509
510
511
std::vector<common_sampler_type> common_sampler_types_from_chars(const std::string & chars) {
    std::unordered_map<char, common_sampler_type> sampler_name_map = {
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY),         COMMON_SAMPLER_TYPE_DRY },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K),       COMMON_SAMPLER_TYPE_TOP_K },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P),   COMMON_SAMPLER_TYPE_TYPICAL_P },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P),       COMMON_SAMPLER_TYPE_TOP_P },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P),       COMMON_SAMPLER_TYPE_MIN_P },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL),      COMMON_SAMPLER_TYPE_INFILL },
512
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES),   COMMON_SAMPLER_TYPE_PENALTIES },
513
    };
514

515
    std::vector<common_sampler_type> samplers;
516
    samplers.reserve(chars.size());
517

518
519
520
521
522
    for (const auto & c : chars) {
        const auto sampler = sampler_name_map.find(c);
        if (sampler != sampler_name_map.end()) {
            samplers.push_back(sampler->second);
        }
523
    }
524
525

    return samplers;
526
}