// common/sampling.cpp
#include "sampling.h"

3
#include "common.h"
4
#include "log.h"
5

Daniel Hiltgen's avatar
Daniel Hiltgen committed
6
#include <algorithm>
7
#include <cmath>
Daniel Hiltgen's avatar
Daniel Hiltgen committed
8
#include <cstring>
9
#include <unordered_map>
// the ring buffer works similarly to std::deque, but with a fixed capacity
// TODO: deduplicate with llama-impl.h
//
// fixed-capacity FIFO: once full, push_back() overwrites the oldest element.
// `first` indexes the oldest element, `pos` the next write slot; both wrap
// around `capacity` using modulo arithmetic.
template<typename T>
struct ring_buffer {
    // cap must be positive - a zero capacity would make the modulo arithmetic
    // in push_back()/back() divide by zero
    ring_buffer(size_t cap) : capacity(cap), data(cap) {
        if (cap == 0) {
            throw std::runtime_error("ring buffer capacity must be positive");
        }
    }

    // oldest element; throws if the buffer is empty
    T & front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    const T & front() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    // newest element; throws if the buffer is empty.
    // note: `pos` is the next write slot, so the newest element is one behind it
    // (the previous `return data[pos];` returned a stale/default slot)
    T & back() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[(pos + capacity - 1) % capacity];
    }

    const T & back() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[(pos + capacity - 1) % capacity];
    }

    // append a value; when full, silently drops the oldest element
    void push_back(const T & value) {
        if (sz == capacity) {
            // advance the start when buffer is full
            first = (first + 1) % capacity;
        } else {
            sz++;
        }
        data[pos] = value;
        pos = (pos + 1) % capacity;
    }

    // remove and return the oldest element; throws if empty
    T pop_front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        T value = data[first];
        first = (first + 1) % capacity;
        sz--;
        return value;
    }

    // reverse-at: rat(0) is the newest element, rat(sz-1) the oldest;
    // throws on out-of-bounds
    const T & rat(size_t i) const {
        if (i >= sz) {
            throw std::runtime_error("ring buffer: index out of bounds");
        }
        return data[(first + sz - i - 1) % capacity];
    }

    // copy the contents, oldest first, into a plain vector
    std::vector<T> to_vector() const {
        std::vector<T> result;
        result.reserve(sz);
        for (size_t i = 0; i < sz; i++) {
            result.push_back(data[(first + i) % capacity]);
        }
        return result;
    }

    void clear() {
        // here only reset the status of the buffer; the storage is kept
        sz = 0;
        first = 0;
        pos = 0;
    }

    bool empty() const {
        return sz == 0;
    }

    size_t size() const {
        return sz;
    }

    size_t capacity = 0;
    size_t sz = 0;      // number of valid elements
    size_t first = 0;   // index of the oldest element
    size_t pos = 0;     // index of the next write slot
    std::vector<T> data;
};
103

104
105
// bundles the grammar sampler, the sampler chain and the bookkeeping state
// (token history, candidate scratch buffer, timing) behind the common_sampler_* API
struct common_sampler {
    common_params_sampling params;

    struct llama_sampler * grmr;  // grammar sampler, kept separate from the chain
    struct llama_sampler * chain; // the actual sampler chain

    ring_buffer<llama_token> prev; // history of accepted tokens, newest last

    std::vector<llama_token_data> cur; // candidate scratch buffer for the current position

    llama_token_data_array cur_p; // view over `cur`; refreshed by set_logits()

    // clear the token history and reset both samplers
    void reset() {
        prev.clear();

        llama_sampler_reset(grmr);
        llama_sampler_reset(chain);
    }

    // refill `cur`/`cur_p` from the logits of output position `idx` in `ctx`;
    // afterwards no candidate is selected (-1) and the array is unsorted
    void set_logits(struct llama_context * ctx, int idx) {
        const auto * logits = llama_get_logits_ith(ctx, idx);

        const llama_model * model = llama_get_model(ctx);
        const llama_vocab * vocab = llama_model_get_vocab(model);

        const int n_vocab = llama_vocab_n_tokens(vocab);

        cur.resize(n_vocab);

        // one candidate per vocab token; probabilities start at 0 until a
        // sampler computes them
        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
            cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
        }

        cur_p = { cur.data(), cur.size(), -1, false };
    }

    // scoped time measurement: accumulates elapsed time into t_total_us
    // (disabled when params.no_perf is set)
    common_time_meas tm() {
        return common_time_meas(t_total_us, params.no_perf);
    }

    // total time spent in common/sampling, reported by common_perf_print()
    mutable int64_t t_total_us = 0;
};

147
std::string common_params_sampling::print() const {
148
149
150
151
    char result[1024];

    snprintf(result, sizeof(result),
            "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
152
            "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
153
            "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n"
154
            "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
155
            penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
156
            dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
157
            top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
158
            mirostat, mirostat_eta, mirostat_tau);
159
160
161
162

    return std::string(result);
}

163
// construct a common_sampler from the given params: a grammar sampler (kept
// separate) plus a chain of logit-bias -> user-selected samplers -> dist, or a
// mirostat configuration. returns nullptr if the grammar fails to compile.
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
    const llama_vocab * vocab = llama_model_get_vocab(model);

    llama_sampler_chain_params lparams = llama_sampler_chain_default_params();

    lparams.no_perf = params.no_perf;

    // build the grammar sampler first - it stays outside the chain so that
    // common_sampler_sample() can apply it before or after the chain
    struct llama_sampler * grmr;
    if (params.grammar.compare(0, 11, "%llguidance") == 0) {
        // grammars prefixed with "%llguidance" are routed to the llguidance backend
#ifdef LLAMA_USE_LLGUIDANCE
        grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
#else
        GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE
    } else {
        // full-output trigger patterns for lazy grammars
        std::vector<std::string> trigger_patterns;
        // word/pattern triggers that may match anywhere; folded into a single
        // full-match regex below
        std::vector<std::string> patterns_anywhere;
        std::vector<llama_token> trigger_tokens;
        for (const auto & trigger : params.grammar_triggers) {
            switch (trigger.type) {
                case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
                {
                    // literal word: escape it so it matches verbatim inside a regex
                    const auto & word = trigger.value;
                    patterns_anywhere.push_back(regex_escape(word));
                    break;
                }
                case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
                {
                    patterns_anywhere.push_back(trigger.value);
                    break;
                }
                case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
                {
                    trigger_patterns.push_back(trigger.value);
                    break;
                }
                case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
                {
                    const auto token = trigger.token;
                    trigger_tokens.push_back(token);
                    break;
                }
                default:
                    GGML_ASSERT(false && "unknown trigger type");
            }
        }

        // wrap the "anywhere" patterns so they match when they appear at any
        // position in the output generated so far
        if (!patterns_anywhere.empty()) {
            trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
        }

        std::vector<const char *> trigger_patterns_c;
        trigger_patterns_c.reserve(trigger_patterns.size());
        for (const auto & regex : trigger_patterns) {
            trigger_patterns_c.push_back(regex.c_str());
        }

        grmr = params.grammar_lazy
             ? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
                                                        trigger_patterns_c.data(), trigger_patterns_c.size(),
                                                        trigger_tokens.data(), trigger_tokens.size())
             :      llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
        if (!grmr) {
            // grammar failed to parse/compile
            return nullptr;
        }
    }

    auto * result = new common_sampler {
        /* .params = */ params,
        /* .grmr   = */ grmr,
        /* .chain  = */ llama_sampler_chain_init(lparams),
        /* .prev   = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
        /* .cur    = */ {},
        /* .cur_p  = */ {},
    };

    // logit bias is always applied first, regardless of the sampler selection
    llama_sampler_chain_add(result->chain,
            llama_sampler_init_logit_bias(
                llama_vocab_n_tokens(vocab),
                params.logit_bias.size(),
                params.logit_bias.data()));

    if (params.mirostat == 0) {
        // standard path: add the user-selected samplers in order, then the
        // final dist sampler that actually picks a token
        for (const auto & cnstr : params.samplers) {
            switch (cnstr) {
                case COMMON_SAMPLER_TYPE_DRY:
                    {
                        // the DRY sampler needs C strings for the sequence breakers
                        std::vector<const char *> c_breakers;
                        c_breakers.reserve(params.dry_sequence_breakers.size());
                        for (const auto & str : params.dry_sequence_breakers) {
                            c_breakers.push_back(str.c_str());
                        }

                        llama_sampler_chain_add(result->chain, llama_sampler_init_dry      (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
                    }
                    break;
                case COMMON_SAMPLER_TYPE_TOP_K:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_k       (params.top_k));
                    break;
                case COMMON_SAMPLER_TYPE_TOP_P:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_p       (params.top_p, params.min_keep));
                    break;
                case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
                    break;
                case COMMON_SAMPLER_TYPE_MIN_P:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_min_p       (params.min_p, params.min_keep));
                    break;
                case COMMON_SAMPLER_TYPE_XTC:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_xtc         (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
                    break;
                case COMMON_SAMPLER_TYPE_TYPICAL_P:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_typical     (params.typ_p, params.min_keep));
                    break;
                case COMMON_SAMPLER_TYPE_TEMPERATURE:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext    (params.temp, params.dynatemp_range, params.dynatemp_exponent));
                    break;
                case COMMON_SAMPLER_TYPE_INFILL:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_infill      (vocab));
                    break;
                case COMMON_SAMPLER_TYPE_PENALTIES:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_penalties   (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
                    break;
                default:
                    GGML_ASSERT(false && "unknown sampler type");
            }
        }
        llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
    } else if (params.mirostat == 1) {
        // mirostat v1 replaces the user-selected samplers entirely
        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
    } else if (params.mirostat == 2) {
        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
    } else {
        GGML_ASSERT(false && "unknown mirostat version");
    }

    return result;
}

304
void common_sampler_free(struct common_sampler * gsmpl) {
305
306
307
308
309
310
    if (gsmpl) {
        llama_sampler_free(gsmpl->grmr);

        llama_sampler_free(gsmpl->chain);

        delete gsmpl;
311
312
313
    }
}

314
// record `token` as accepted: advance the samplers' internal state and append
// the token to the history. `accept_grammar` lets the caller decide whether
// the grammar state should advance as well.
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
    // scoped timer - must be constructed first so the whole call is measured
    const auto tm = gsmpl->tm();

    if (accept_grammar) {
        llama_sampler_accept(gsmpl->grmr, token);
    }

    llama_sampler_accept(gsmpl->chain, token);

    gsmpl->prev.push_back(token);
}

326
// clear the token history and reset the grammar sampler and the sampler chain
void common_sampler_reset(struct common_sampler * gsmpl) {
    gsmpl->reset();
}

330
331
// deep-copy a sampler: the llama samplers are cloned, the history and candidate
// buffers are copied. t_total_us is not listed, so the clone's timing
// accumulator starts from its default of 0.
// NOTE(review): the copied cur_p still points into gsmpl->cur, not the clone's
// cur - presumably harmless because set_logits() rebuilds it before use; verify.
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
    return new common_sampler {
        /* .params = */ gsmpl->params,
        /* .grmr   = */ llama_sampler_clone(gsmpl->grmr),
        /* .chain  = */ llama_sampler_clone(gsmpl->chain),
        /* .prev   = */ gsmpl->prev,
        /* .cur    = */ gsmpl->cur,
        /* .cur_p  = */ gsmpl->cur_p,
    };
}

341
// print sampling and context performance counters; either argument may be null,
// in which case the corresponding section is skipped
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
    // TODO: measure grammar performance

    // total time spent in common/sampling (accumulated via common_sampler::tm)
    const double t_sampling_ms = gsmpl ? 1e-3*gsmpl->t_total_us : 0;

    llama_perf_sampler_data data_smpl;
    llama_perf_context_data data_ctx;

    memset(&data_smpl, 0, sizeof(data_smpl));
    memset(&data_ctx,  0, sizeof(data_ctx));

    if (gsmpl) {
        auto & data = data_smpl;

        data = llama_perf_sampler(gsmpl->chain);

        // note: the sampling time includes the samplers time + extra time spent in common/sampling
        LOG_INF("%s:    sampling time = %10.2f ms\n", __func__, t_sampling_ms);
        LOG_INF("%s:    samplers time = %10.2f ms / %5d tokens\n", __func__, data.t_sample_ms, data.n_sample);
    }

    if (ctx) {
        auto & data = data_ctx;

        data = llama_perf_context(ctx);

        const double t_end_ms = 1e-3 * ggml_time_us();

        // time not attributed to sampling, prompt eval or eval
        const double t_total_ms = t_end_ms - data.t_start_ms;
        const double t_unacc_ms = t_total_ms - (t_sampling_ms + data.t_p_eval_ms + data.t_eval_ms);
        const double t_unacc_pc = 100.0 * t_unacc_ms /  t_total_ms;

        LOG_INF("%s:        load time = %10.2f ms\n", __func__, data.t_load_ms);
        LOG_INF("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
                __func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
        LOG_INF("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
                __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
        LOG_INF("%s:       total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
        LOG_INF("%s: unaccounted time = %10.2f ms / %5.1f %%      (total - sampling - prompt eval - eval) / (total)\n", __func__, t_unacc_ms, t_unacc_pc);
        LOG_INF("%s:    graphs reused = %10d\n", __func__, data.n_reused);

        llama_memory_breakdown_print(ctx);
    }
}
385

386
// sample one token from output position `idx`.
// grammar_first == false (default): sample with the chain first, then verify the
// token against the grammar and resample with the grammar constraint applied if
// it was rejected. grammar_first == true: constrain with the grammar up front
// (no verification pass needed).
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
    llama_synchronize(ctx);

    // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
    const auto tm = gsmpl->tm();

    gsmpl->set_logits(ctx, idx);

    auto & grmr  = gsmpl->grmr;
    auto & chain = gsmpl->chain;
    auto & cur_p = gsmpl->cur_p; // initialized by set_logits

    if (grammar_first) {
        llama_sampler_apply(grmr, &cur_p);
    }

    llama_sampler_apply(chain, &cur_p);

    GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");

    const llama_token id = cur_p.data[cur_p.selected].id;

    // the grammar already constrained the candidates - nothing to verify
    if (grammar_first) {
        return id;
    }

    // check if the sampled token fits the grammar
    {
        llama_token_data       single_token_data       = { id, 1.0f, 0.0f };
        llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };

        llama_sampler_apply(grmr, &single_token_data_array);

        // the grammar marks rejected tokens by setting their logit to -INFINITY
        const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
        if (is_valid) {
            return id;
        }
    }

    // resampling:
    // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
    gsmpl->set_logits(ctx, idx);

    llama_sampler_apply(grmr,  &cur_p);
    llama_sampler_apply(chain, &cur_p);

    GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");

    return cur_p.data[cur_p.selected].id;
}
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
    GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");

    std::vector<llama_token> result;
    result.reserve(idxs.size());

    size_t i = 0;
    for (; i < draft.size(); i++) {
        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);

        common_sampler_accept(gsmpl, id, true);

        result.push_back(id);

        if (draft[i] != id) {
            break;
        }
    }

    if (i == draft.size()) {
        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);

        common_sampler_accept(gsmpl, id, true);

        result.push_back(id);
    }

    return result;
}

// convenience overload: positions are simply 0..draft.size()
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
    std::vector<int> idxs;
    idxs.reserve(draft.size() + 1);
    for (size_t i = 0; i <= draft.size(); ++i) {
        idxs.push_back((int) i);
    }

    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
}

// the effective seed is owned by the chain (e.g. by the dist/mirostat sampler)
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
    return llama_sampler_get_seed(gsmpl->chain);
}

// helpers
481

Daniel Hiltgen's avatar
Daniel Hiltgen committed
482
// expose the current candidate array (filled by the last sample call).
// with do_sort == true the candidates are sorted by descending probability,
// keeping `selected` pointing at the same token id.
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
    const auto tm = gsmpl->tm();

    auto * cand = &gsmpl->cur_p;

    if (do_sort && !cand->sorted) {
        // remember which token was selected - sorting moves it
        const llama_token selected_id = cand->data[cand->selected].id;

        std::sort(cand->data, cand->data + cand->size,
                [](const llama_token_data & lhs, const llama_token_data & rhs) {
                    return lhs.p > rhs.p;
                });

        // re-locate the selected token in the sorted array
        for (size_t j = 0; j < cand->size; ++j) {
            if (cand->data[j].id == selected_id) {
                cand->selected = j;
                break;
            }
        }

        cand->sorted = true;
    }

    return cand;
}
508

509
// most recently accepted token; rat() throws if nothing was accepted yet
llama_token common_sampler_last(const struct common_sampler * gsmpl) {
    return gsmpl->prev.rat(0);
}
512

513
std::string common_sampler_print(const struct common_sampler * gsmpl) {
514
    std::string result = "logits ";
515

516
517
518
    for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
        const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
        result += std::string("-> ") + llama_sampler_name(smpl) + " ";
519
520
    }

521
522
523
    return result;
}

524
// detokenize the last `n` accepted tokens (oldest first); returns "" when the
// history is empty or n <= 0
std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_main, int n) {
    n = std::min(n, (int) gsmpl->prev.size());

    if (n <= 0) {
        return "";
    }

    std::string text;
    text.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab

    // rat(n-1) is the oldest of the last n tokens, rat(0) the newest
    for (int back = n - 1; back >= 0; --back) {
        const llama_token id = gsmpl->prev.rat(back);

        GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen");

        text += common_token_to_piece(ctx_main, id);
    }

    return text;
}

545
char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
546
    switch (cnstr) {
547
548
549
550
        case COMMON_SAMPLER_TYPE_DRY:         return 'd';
        case COMMON_SAMPLER_TYPE_TOP_K:       return 'k';
        case COMMON_SAMPLER_TYPE_TYPICAL_P:   return 'y';
        case COMMON_SAMPLER_TYPE_TOP_P:       return 'p';
551
        case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's';
552
553
554
555
        case COMMON_SAMPLER_TYPE_MIN_P:       return 'm';
        case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
        case COMMON_SAMPLER_TYPE_XTC:         return 'x';
        case COMMON_SAMPLER_TYPE_INFILL:      return 'i';
556
        case COMMON_SAMPLER_TYPE_PENALTIES:   return 'e';
557
558
559
        default : return '?';
    }
}
560

561
std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
562
    switch (cnstr) {
563
564
565
566
        case COMMON_SAMPLER_TYPE_DRY:         return "dry";
        case COMMON_SAMPLER_TYPE_TOP_K:       return "top_k";
        case COMMON_SAMPLER_TYPE_TYPICAL_P:   return "typ_p";
        case COMMON_SAMPLER_TYPE_TOP_P:       return "top_p";
567
        case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma";
568
569
570
571
        case COMMON_SAMPLER_TYPE_MIN_P:       return "min_p";
        case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
        case COMMON_SAMPLER_TYPE_XTC:         return "xtc";
        case COMMON_SAMPLER_TYPE_INFILL:      return "infill";
572
        case COMMON_SAMPLER_TYPE_PENALTIES:   return "penalties";
573
        default : return "";
574
    }
575
}
576

577
578
579
580
581
std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
    std::unordered_map<std::string, common_sampler_type> sampler_canonical_name_map {
        { "dry",         COMMON_SAMPLER_TYPE_DRY },
        { "top_k",       COMMON_SAMPLER_TYPE_TOP_K },
        { "top_p",       COMMON_SAMPLER_TYPE_TOP_P },
582
        { "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
583
584
585
586
587
        { "typ_p",       COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "min_p",       COMMON_SAMPLER_TYPE_MIN_P },
        { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
        { "xtc",         COMMON_SAMPLER_TYPE_XTC },
        { "infill",      COMMON_SAMPLER_TYPE_INFILL },
588
        { "penalties",   COMMON_SAMPLER_TYPE_PENALTIES },
589
    };
590

591
592
    // since samplers names are written multiple ways
    // make it ready for both system names and input names
593
594
595
    std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
        { "top-k",       COMMON_SAMPLER_TYPE_TOP_K },
        { "top-p",       COMMON_SAMPLER_TYPE_TOP_P },
596
        { "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
597
598
599
600
601
602
603
        { "nucleus",     COMMON_SAMPLER_TYPE_TOP_P },
        { "typical-p",   COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "typical",     COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "typ-p",       COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "typ",         COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "min-p",       COMMON_SAMPLER_TYPE_MIN_P },
        { "temp",        COMMON_SAMPLER_TYPE_TEMPERATURE },
604
    };
605

606
    std::vector<common_sampler_type> samplers;
607
    samplers.reserve(names.size());
608

609
610
611
612
    for (const auto & name : names) {
        auto sampler = sampler_canonical_name_map.find(name);
        if (sampler != sampler_canonical_name_map.end()) {
            samplers.push_back(sampler->second);
613
614
615
616
617
618
619
            continue;
        }
        if (allow_alt_names) {
            sampler = sampler_alt_name_map.find(name);
            if (sampler != sampler_alt_name_map.end()) {
                samplers.push_back(sampler->second);
                continue;
620
621
            }
        }
622
        LOG_WRN("%s: unable to match sampler by name '%s'\n", __func__, name.c_str());
623
624
    }

625
    return samplers;
626
627
}

628
629
630
631
632
633
std::vector<common_sampler_type> common_sampler_types_from_chars(const std::string & chars) {
    std::unordered_map<char, common_sampler_type> sampler_name_map = {
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY),         COMMON_SAMPLER_TYPE_DRY },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K),       COMMON_SAMPLER_TYPE_TOP_K },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P),   COMMON_SAMPLER_TYPE_TYPICAL_P },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P),       COMMON_SAMPLER_TYPE_TOP_P },
634
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
635
636
637
638
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P),       COMMON_SAMPLER_TYPE_MIN_P },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL),      COMMON_SAMPLER_TYPE_INFILL },
639
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES),   COMMON_SAMPLER_TYPE_PENALTIES },
640
    };
641

642
    std::vector<common_sampler_type> samplers;
643
    samplers.reserve(chars.size());
644

645
646
647
648
    for (const auto & c : chars) {
        const auto sampler = sampler_name_map.find(c);
        if (sampler != sampler_name_map.end()) {
            samplers.push_back(sampler->second);
649
650
        } else {
            LOG_WRN("%s: unable to match sampler by char '%c'\n", __func__, c);
651
        }
652
    }
653
654

    return samplers;
655
}