sampling.cpp 17.4 KB
Newer Older
xuxzh1's avatar
init  
xuxzh1 committed
1
2
#include "sampling.h"

xuxzh1's avatar
update  
xuxzh1 committed
3
#include "common.h"
xuxzh1's avatar
init  
xuxzh1 committed
4

xuxzh1's avatar
update  
xuxzh1 committed
5
6
#include <cmath>
#include <unordered_map>
xuxzh1's avatar
init  
xuxzh1 committed
7

xuxzh1's avatar
update  
xuxzh1 committed
8
9
10
11
12
// the ring buffer works similarly to std::deque, but with a fixed capacity
// TODO: deduplicate with llama-impl.h
template<typename T>
struct ring_buffer {
    // Allocates storage for `cap` elements up front; the buffer starts out empty.
    ring_buffer(size_t cap) : capacity(cap), data(cap) {}

    // Returns the oldest stored element.
    // Throws std::runtime_error if the buffer is empty.
    T & front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    const T & front() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[first];
    }

    // Returns the most recently pushed element.
    // Throws std::runtime_error if the buffer is empty.
    T & back() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[last_index()];
    }

    const T & back() const {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        return data[last_index()];
    }

    // Appends `value`. When the buffer is already full the oldest element is overwritten.
    // Throws std::runtime_error if the buffer was created with zero capacity.
    void push_back(const T & value) {
        if (capacity == 0) {
            // the modulo arithmetic below would divide by zero
            throw std::runtime_error("ring buffer: capacity is zero");
        }

        if (sz == capacity) {
            // advance the start when buffer is full
            first = (first + 1) % capacity;
        } else {
            sz++;
        }
        data[pos] = value;
        pos = (pos + 1) % capacity;
    }

    // Removes the oldest element and returns a copy of it.
    // Throws std::runtime_error if the buffer is empty.
    T pop_front() {
        if (sz == 0) {
            throw std::runtime_error("ring buffer is empty");
        }
        T value = data[first];
        first = (first + 1) % capacity;
        sz--;
        return value;
    }

    // Reverse access: rat(0) is the newest element, rat(size() - 1) the oldest.
    // Throws std::runtime_error if `i` is not smaller than size().
    const T & rat(size_t i) const {
        if (i >= sz) {
            throw std::runtime_error("ring buffer: index out of bounds");
        }
        return data[(first + sz - i - 1) % capacity];
    }

    // Copies the stored elements, oldest first, into a new vector.
    std::vector<T> to_vector() const {
        std::vector<T> result;
        result.reserve(sz);
        for (size_t i = 0; i < sz; i++) {
            result.push_back(data[(first + i) % capacity]);
        }
        return result;
    }

    void clear() {
        // here only reset the status of the buffer
        sz = 0;
        first = 0;
        pos = 0;
    }

    bool empty() const {
        return sz == 0;
    }

    size_t size() const {
        return sz;
    }

    // Index of the newest element. `pos` is the *next* write slot, so the newest
    // element lives one slot before it (with wrap-around). Only valid when sz > 0.
    size_t last_index() const {
        return (pos + capacity - 1) % capacity;
    }

    size_t capacity = 0; // maximum number of elements
    size_t sz = 0;       // number of elements currently stored
    size_t first = 0;    // index of the oldest element
    size_t pos = 0;      // index of the next slot to write
    std::vector<T> data;
};
xuxzh1's avatar
init  
xuxzh1 committed
100

xuxzh1's avatar
update  
xuxzh1 committed
101
102
struct common_sampler {
    common_params_sampling params;
xuxzh1's avatar
init  
xuxzh1 committed
103

xuxzh1's avatar
update  
xuxzh1 committed
104
105
    struct llama_sampler * grmr;
    struct llama_sampler * chain;
xuxzh1's avatar
init  
xuxzh1 committed
106

xuxzh1's avatar
update  
xuxzh1 committed
107
    ring_buffer<llama_token> prev;
xuxzh1's avatar
init  
xuxzh1 committed
108

xuxzh1's avatar
update  
xuxzh1 committed
109
    std::vector<llama_token_data> cur;
xuxzh1's avatar
init  
xuxzh1 committed
110

xuxzh1's avatar
update  
xuxzh1 committed
111
    llama_token_data_array cur_p;
xuxzh1's avatar
init  
xuxzh1 committed
112

xuxzh1's avatar
update  
xuxzh1 committed
113
114
115
116
117
118
119
120
121
122
    void set_logits(struct llama_context * ctx, int idx) {
        const auto * logits = llama_get_logits_ith(ctx, idx);

        const int n_vocab = llama_n_vocab(llama_get_model(ctx));

        cur.resize(n_vocab);

        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
            cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
        }
xuxzh1's avatar
init  
xuxzh1 committed
123

xuxzh1's avatar
update  
xuxzh1 committed
124
125
126
127
128
        cur_p = { cur.data(), cur.size(), -1, false };
    }
};

std::string common_params_sampling::print() const {
xuxzh1's avatar
init  
xuxzh1 committed
129
130
131
132
    char result[1024];

    snprintf(result, sizeof(result),
            "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
xuxzh1's avatar
update  
xuxzh1 committed
133
134
            "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
            "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
xuxzh1's avatar
init  
xuxzh1 committed
135
            "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
xuxzh1's avatar
update  
xuxzh1 committed
136
137
138
139
            penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
            dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
            top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
            mirostat, mirostat_eta, mirostat_tau);
xuxzh1's avatar
init  
xuxzh1 committed
140
141
142
143

    return std::string(result);
}

xuxzh1's avatar
update  
xuxzh1 committed
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
    llama_sampler_chain_params lparams = llama_sampler_chain_default_params();

    lparams.no_perf = params.no_perf;

    auto * result = new common_sampler {
        /* .params = */ params,
        /* .grmr   = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
        /* .chain  = */ llama_sampler_chain_init(lparams),
        /* .prev   = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
        /* .cur    = */ {},
        /* .cur_p  = */ {},
    };

    llama_sampler_chain_add(result->chain,
            llama_sampler_init_logit_bias(
                llama_n_vocab(model),
                params.logit_bias.size(),
                params.logit_bias.data()));

    llama_sampler_chain_add(result->chain,
            llama_sampler_init_penalties(
                llama_n_vocab  (model),
                llama_token_eos(model),
                llama_token_nl (model),
                params.penalty_last_n,
                params.penalty_repeat,
                params.penalty_freq,
                params.penalty_present,
                params.penalize_nl,
                params.ignore_eos));

xuxzh1's avatar
init  
xuxzh1 committed
176
    if (params.mirostat == 0) {
xuxzh1's avatar
update  
xuxzh1 committed
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
        for (const auto & cnstr : params.samplers) {
            switch (cnstr) {
                    case COMMON_SAMPLER_TYPE_DRY:
                    {
                        std::vector<const char*> c_breakers;
                        c_breakers.reserve(params.dry_sequence_breakers.size());
                        for (const auto& str : params.dry_sequence_breakers) {
                            c_breakers.push_back(str.c_str());
                        }

                        llama_sampler_chain_add(result->chain, llama_sampler_init_dry      (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
                    }
                        break;
                case COMMON_SAMPLER_TYPE_TOP_K:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_k    (params.top_k));
                    break;
                case COMMON_SAMPLER_TYPE_TOP_P:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_p    (params.top_p, params.min_keep));
                    break;
                case COMMON_SAMPLER_TYPE_MIN_P:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_min_p    (params.min_p, params.min_keep));
                    break;
                case COMMON_SAMPLER_TYPE_XTC:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_xtc      (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
                    break;
                case COMMON_SAMPLER_TYPE_TYPICAL_P:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_typical  (params.typ_p, params.min_keep));
                    break;
                case COMMON_SAMPLER_TYPE_TEMPERATURE:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
                    break;
                case COMMON_SAMPLER_TYPE_INFILL:
                    llama_sampler_chain_add(result->chain, llama_sampler_init_infill   (model));
                    break;
                default:
                    GGML_ASSERT(false && "unknown sampler type");
xuxzh1's avatar
init  
xuxzh1 committed
213
214
            }
        }
xuxzh1's avatar
update  
xuxzh1 committed
215
216
217
218
219
220
221
        llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
    } else if (params.mirostat == 1) {
        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
    } else if (params.mirostat == 2) {
        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
xuxzh1's avatar
init  
xuxzh1 committed
222
    } else {
xuxzh1's avatar
update  
xuxzh1 committed
223
        GGML_ASSERT(false && "unknown mirostat version");
xuxzh1's avatar
init  
xuxzh1 committed
224
225
226
227
228
    }

    return result;
}

xuxzh1's avatar
update  
xuxzh1 committed
229
230
231
232
233
234
235
// Frees the grammar sampler, the sampler chain and the common_sampler itself.
// Passing a null pointer is a no-op.
void common_sampler_free(struct common_sampler * gsmpl) {
    if (!gsmpl) {
        return;
    }

    llama_sampler_free(gsmpl->grmr);
    llama_sampler_free(gsmpl->chain);

    delete gsmpl;
}

xuxzh1's avatar
update  
xuxzh1 committed
239
240
241
242
// Feeds `token` to the sampler chain and appends it to the token history.
// The grammar sampler only receives the token when `accept_grammar` is true.
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
    auto & smpl = *gsmpl;

    if (accept_grammar) {
        llama_sampler_accept(smpl.grmr, token);
    }

    llama_sampler_accept(smpl.chain, token);

    smpl.prev.push_back(token);
}

xuxzh1's avatar
update  
xuxzh1 committed
249
250
// Resets the grammar sampler and the sampler chain.
// The token history (`prev`) is left untouched.
void common_sampler_reset(struct common_sampler * gsmpl) {
    auto * grammar = gsmpl->grmr;
    auto * chain   = gsmpl->chain;

    llama_sampler_reset(grammar);
    llama_sampler_reset(chain);
}

xuxzh1's avatar
update  
xuxzh1 committed
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
// Creates an independent copy of `gsmpl`: the parameters, token history and
// candidate buffer are copied, and both samplers are cloned.
//
// The returned object must be released with common_sampler_free.
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
    auto * result = new common_sampler {
        /* .params = */ gsmpl->params,
        /* .grmr   = */ llama_sampler_clone(gsmpl->grmr),
        /* .chain  = */ llama_sampler_clone(gsmpl->chain),
        /* .prev   = */ gsmpl->prev,
        /* .cur    = */ gsmpl->cur,
        /* .cur_p  = */ gsmpl->cur_p,
    };

    // cur_p is a non-owning view. A plain copy would keep pointing at the source
    // sampler's `cur` storage, so the clone would alias it (and dangle once the
    // source is freed). Re-target the view at the clone's own copy of the buffer.
    if (gsmpl->cur_p.data == gsmpl->cur.data()) {
        result->cur_p.data = result->cur.data();
    }

    return result;
}

// Prints sampler-chain and context performance information.
// Either argument may be null, in which case that part is skipped.
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
    // TODO: measure grammar performance

    const bool have_sampler = gsmpl != nullptr;
    const bool have_context = ctx   != nullptr;

    if (have_sampler) {
        llama_perf_sampler_print(gsmpl->chain);
    }

    if (have_context) {
        llama_perf_context_print(ctx);
    }
}

xuxzh1's avatar
update  
xuxzh1 committed
277
278
279
280
281
282
283
284
285
// Samples one token from the logits of output `idx` in `ctx`.
//
// With grammar_first == true the grammar sampler is applied to the full candidate
// list before the chain, and the selected token is returned directly.
//
// Otherwise the chain runs alone first; the selected token is then checked against
// the grammar in isolation. Only if the grammar rejects it (logit forced to
// -INFINITY) are the logits reloaded and the token re-sampled with the grammar
// applied before the chain.
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
    gsmpl->set_logits(ctx, idx);

    auto * grammar = gsmpl->grmr;
    auto * samplers = gsmpl->chain;
    auto & cur_p = gsmpl->cur_p; // freshly initialized by set_logits

    if (grammar_first) {
        llama_sampler_apply(grammar, &cur_p);
    }

    llama_sampler_apply(samplers, &cur_p);

    GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");

    const llama_token id = cur_p.data[cur_p.selected].id;

    if (grammar_first) {
        // the grammar constraints were already enforced above
        return id;
    }

    // check whether the sampled token fits the grammar by running the grammar
    // sampler over a candidate list that contains only that token
    llama_token_data       probe     = { id, 1.0f, 0.0f };
    llama_token_data_array probe_arr = { &probe, 1, -1, false };

    llama_sampler_apply(grammar, &probe_arr);

    const bool rejected = probe_arr.data[0].logit == -INFINITY;
    if (!rejected) {
        return id;
    }

    // resampling:
    // the token is not valid, so sample again - this time apply the grammar sampler first and then the sampling chain
    gsmpl->set_logits(ctx, idx);

    llama_sampler_apply(grammar,  &cur_p);
    llama_sampler_apply(samplers, &cur_p);

    GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");

    return cur_p.data[cur_p.selected].id;
}
xuxzh1's avatar
init  
xuxzh1 committed
322

xuxzh1's avatar
update  
xuxzh1 committed
323
324
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
    GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
xuxzh1's avatar
init  
xuxzh1 committed
325

xuxzh1's avatar
update  
xuxzh1 committed
326
327
    std::vector<llama_token> result;
    result.reserve(idxs.size());
xuxzh1's avatar
init  
xuxzh1 committed
328

xuxzh1's avatar
update  
xuxzh1 committed
329
330
331
    size_t i = 0;
    for (; i < draft.size(); i++) {
        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
xuxzh1's avatar
init  
xuxzh1 committed
332

xuxzh1's avatar
update  
xuxzh1 committed
333
        common_sampler_accept(gsmpl, id, true);
xuxzh1's avatar
init  
xuxzh1 committed
334

xuxzh1's avatar
update  
xuxzh1 committed
335
        result.push_back(id);
xuxzh1's avatar
init  
xuxzh1 committed
336

xuxzh1's avatar
update  
xuxzh1 committed
337
338
        if (draft[i] != id) {
            break;
xuxzh1's avatar
init  
xuxzh1 committed
339
340
341
        }
    }

xuxzh1's avatar
update  
xuxzh1 committed
342
343
344
345
346
347
348
    if (i == draft.size()) {
        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);

        common_sampler_accept(gsmpl, id, true);

        result.push_back(id);
    }
xuxzh1's avatar
init  
xuxzh1 committed
349

xuxzh1's avatar
update  
xuxzh1 committed
350
    return result;
xuxzh1's avatar
init  
xuxzh1 committed
351
352
}

xuxzh1's avatar
update  
xuxzh1 committed
353
354
355
356
357
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
    std::vector<int> idxs(draft.size() + 1);
    for (size_t i = 0; i < idxs.size(); ++i) {
        idxs[i] = i;
    }
xuxzh1's avatar
init  
xuxzh1 committed
358

xuxzh1's avatar
update  
xuxzh1 committed
359
360
    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
}
xuxzh1's avatar
init  
xuxzh1 committed
361

xuxzh1's avatar
update  
xuxzh1 committed
362
363
364
// Returns the seed reported by the sampler chain.
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
    const uint32_t seed = llama_sampler_get_seed(gsmpl->chain);
    return seed;
}
xuxzh1's avatar
init  
xuxzh1 committed
365

xuxzh1's avatar
update  
xuxzh1 committed
366
// helpers
xuxzh1's avatar
init  
xuxzh1 committed
367

xuxzh1's avatar
update  
xuxzh1 committed
368
369
370
371
372
373
374
// Gives access to the sampler's current candidate array (owned by `gsmpl`).
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl) {
    auto & candidates = gsmpl->cur_p;
    return &candidates;
}

// Returns the newest token in the history.
// The underlying ring buffer throws std::runtime_error when the history is empty.
llama_token common_sampler_last(const struct common_sampler * gsmpl) {
    const auto & history = gsmpl->prev;
    return history.rat(0);
}
xuxzh1's avatar
init  
xuxzh1 committed
375

xuxzh1's avatar
update  
xuxzh1 committed
376
377
std::string common_sampler_print(const struct common_sampler * gsmpl) {
    std::string result = "logits ";
xuxzh1's avatar
init  
xuxzh1 committed
378

xuxzh1's avatar
update  
xuxzh1 committed
379
380
381
    for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
        const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
        result += std::string("-> ") + llama_sampler_name(smpl) + " ";
xuxzh1's avatar
init  
xuxzh1 committed
382
383
    }

xuxzh1's avatar
update  
xuxzh1 committed
384
385
386
387
388
389
390
391
    return result;
}

// Renders the last `n` tokens of the history (fewer if the history is shorter)
// as text, oldest first. Returns an empty string when there is nothing to render.
std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_main, int n) {
    const auto & history = gsmpl->prev;

    const int n_tokens = std::min(n, (int) history.size());
    if (n_tokens <= 0) {
        return std::string();
    }

    std::string text;
    text.reserve(8*n_tokens); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab

    // rat(0) is the newest token, so count down to emit the tokens oldest-first
    for (int i = n_tokens; i-- > 0; ) {
        const llama_token id = history.rat(i);

        GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen");

        text += common_token_to_piece(ctx_main, id);
    }

    return text;
}
xuxzh1's avatar
init  
xuxzh1 committed
407

xuxzh1's avatar
update  
xuxzh1 committed
408
409
410
411
412
413
414
415
416
417
418
// Maps a sampler type to its single-character code; unknown types map to '?'.
char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
    char chr = '?';

    switch (cnstr) {
        case COMMON_SAMPLER_TYPE_DRY:         chr = 'd'; break;
        case COMMON_SAMPLER_TYPE_TOP_K:       chr = 'k'; break;
        case COMMON_SAMPLER_TYPE_TYPICAL_P:   chr = 'y'; break;
        case COMMON_SAMPLER_TYPE_TOP_P:       chr = 'p'; break;
        case COMMON_SAMPLER_TYPE_MIN_P:       chr = 'm'; break;
        case COMMON_SAMPLER_TYPE_TEMPERATURE: chr = 't'; break;
        case COMMON_SAMPLER_TYPE_XTC:         chr = 'x'; break;
        case COMMON_SAMPLER_TYPE_INFILL:      chr = 'i'; break;
        default:                                         break;
    }

    return chr;
}
xuxzh1's avatar
init  
xuxzh1 committed
421

xuxzh1's avatar
update  
xuxzh1 committed
422
423
424
425
426
427
428
429
430
431
432
433
434
// Maps a sampler type to its canonical name; unknown types map to an empty string.
std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
    const char * name = "";

    switch (cnstr) {
        case COMMON_SAMPLER_TYPE_DRY:         name = "dry";         break;
        case COMMON_SAMPLER_TYPE_TOP_K:       name = "top_k";       break;
        case COMMON_SAMPLER_TYPE_TYPICAL_P:   name = "typ_p";       break;
        case COMMON_SAMPLER_TYPE_TOP_P:       name = "top_p";       break;
        case COMMON_SAMPLER_TYPE_MIN_P:       name = "min_p";       break;
        case COMMON_SAMPLER_TYPE_TEMPERATURE: name = "temperature"; break;
        case COMMON_SAMPLER_TYPE_XTC:         name = "xtc";         break;
        case COMMON_SAMPLER_TYPE_INFILL:      name = "infill";      break;
        default:                                                    break;
    }

    return name;
}
xuxzh1's avatar
init  
xuxzh1 committed
435

xuxzh1's avatar
update  
xuxzh1 committed
436
437
438
439
440
441
442
443
444
445
446
std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
    std::unordered_map<std::string, common_sampler_type> sampler_canonical_name_map {
        { "dry",         COMMON_SAMPLER_TYPE_DRY },
        { "top_k",       COMMON_SAMPLER_TYPE_TOP_K },
        { "top_p",       COMMON_SAMPLER_TYPE_TOP_P },
        { "typ_p",       COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "min_p",       COMMON_SAMPLER_TYPE_MIN_P },
        { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
        { "xtc",         COMMON_SAMPLER_TYPE_XTC },
        { "infill",      COMMON_SAMPLER_TYPE_INFILL },
    };
xuxzh1's avatar
init  
xuxzh1 committed
447

xuxzh1's avatar
update  
xuxzh1 committed
448
449
450
451
452
453
454
455
456
457
458
459
460
    // since samplers names are written multiple ways
    // make it ready for both system names and input names
    std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
        { "top-k",       COMMON_SAMPLER_TYPE_TOP_K },
        { "top-p",       COMMON_SAMPLER_TYPE_TOP_P },
        { "nucleus",     COMMON_SAMPLER_TYPE_TOP_P },
        { "typical-p",   COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "typical",     COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "typ-p",       COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "typ",         COMMON_SAMPLER_TYPE_TYPICAL_P },
        { "min-p",       COMMON_SAMPLER_TYPE_MIN_P },
        { "temp",        COMMON_SAMPLER_TYPE_TEMPERATURE },
    };
xuxzh1's avatar
init  
xuxzh1 committed
461

xuxzh1's avatar
update  
xuxzh1 committed
462
463
464
465
466
467
468
469
470
471
472
473
    std::vector<common_sampler_type> samplers;
    samplers.reserve(names.size());

    for (const auto & name : names) {
        auto sampler = sampler_canonical_name_map.find(name);
        if (sampler != sampler_canonical_name_map.end()) {
            samplers.push_back(sampler->second);
        } else {
            if (allow_alt_names) {
                sampler = sampler_alt_name_map.find(name);
                if (sampler != sampler_alt_name_map.end()) {
                    samplers.push_back(sampler->second);
xuxzh1's avatar
init  
xuxzh1 committed
474
475
476
477
478
                }
            }
        }
    }

xuxzh1's avatar
update  
xuxzh1 committed
479
    return samplers;
xuxzh1's avatar
init  
xuxzh1 committed
480
481
}

xuxzh1's avatar
update  
xuxzh1 committed
482
483
484
485
486
487
488
489
490
491
492
std::vector<common_sampler_type> common_sampler_types_from_chars(const std::string & chars) {
    std::unordered_map<char, common_sampler_type> sampler_name_map = {
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY),         COMMON_SAMPLER_TYPE_DRY },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K),       COMMON_SAMPLER_TYPE_TOP_K },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P),   COMMON_SAMPLER_TYPE_TYPICAL_P },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P),       COMMON_SAMPLER_TYPE_TOP_P },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P),       COMMON_SAMPLER_TYPE_MIN_P },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC },
        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL),      COMMON_SAMPLER_TYPE_INFILL },
    };
xuxzh1's avatar
init  
xuxzh1 committed
493

xuxzh1's avatar
update  
xuxzh1 committed
494
495
    std::vector<common_sampler_type> samplers;
    samplers.reserve(chars.size());
xuxzh1's avatar
init  
xuxzh1 committed
496

xuxzh1's avatar
update  
xuxzh1 committed
497
498
499
500
501
    for (const auto & c : chars) {
        const auto sampler = sampler_name_map.find(c);
        if (sampler != sampler_name_map.end()) {
            samplers.push_back(sampler->second);
        }
xuxzh1's avatar
init  
xuxzh1 committed
502
    }
xuxzh1's avatar
update  
xuxzh1 committed
503
504

    return samplers;
xuxzh1's avatar
init  
xuxzh1 committed
505
}