llama-sampling.h 1.46 KB
Newer Older
xuxzh1's avatar
init  
xuxzh1 committed
1
2
#pragma once

xuxzh1's avatar
update  
xuxzh1 committed
3
// TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ?
xuxzh1's avatar
init  
xuxzh1 committed
4

xuxzh1's avatar
update  
xuxzh1 committed
5
#include "llama-grammar.h"
xuxzh1's avatar
init  
xuxzh1 committed
6

xuxzh1's avatar
update  
xuxzh1 committed
7
8
struct llama_vocab;
struct llama_grammar;
xuxzh1's avatar
init  
xuxzh1 committed
9

xuxzh1's avatar
update  
xuxzh1 committed
10
// sampler chain
xuxzh1's avatar
init  
xuxzh1 committed
11

xuxzh1's avatar
update  
xuxzh1 committed
12
13
// State of a sampler chain: its parameters, the samplers it holds,
// and counters used for timing.
struct llama_sampler_chain {
    // parameters of the chain
    llama_sampler_chain_params params;

    // samplers in the chain, stored as raw pointers
    // NOTE(review): ownership of the pointed-to samplers is not visible in this header - confirm in the .cpp
    std::vector<struct llama_sampler *> samplers;

    // timing
    // both counters are `mutable` so they can be updated through a const chain

    // presumably accumulated sampling time in microseconds (per the `_us` suffix) - confirm
    mutable int64_t t_sample_us;

    // presumably the number of samples taken - confirm
    mutable int32_t n_sample;
};
xuxzh1's avatar
init  
xuxzh1 committed
23

xuxzh1's avatar
update  
xuxzh1 committed
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
// Declaration only; the definition is not in this header.
// Takes a vocabulary reference and two C strings and returns a sampler pointer.
// presumably grammar_str is the grammar text and grammar_root the name of its root rule - confirm
struct llama_sampler * llama_sampler_init_grammar_impl(
    const struct llama_vocab & vocab,
    const char * grammar_str,
    const char * grammar_root);

// Declaration only; the definition is not in this header.
// Takes a vocabulary reference and returns a sampler pointer.
struct llama_sampler * llama_sampler_init_infill_impl(const struct llama_vocab & vocab);

// Declaration only; the definition is not in this header.
// Takes a vocabulary reference, the DRY settings, and an array of
// C-string sequence breakers together with its element count.
// NOTE(review): the meaning and valid range of each dry_* value is defined by the implementation - confirm there
struct llama_sampler * llama_sampler_init_dry_impl(
    const struct llama_vocab & vocab,
    int32_t context_size,
    float dry_multiplier,
    float dry_base,
    int32_t dry_allowed_length,
    int32_t dry_penalty_last_n,
    const char ** seq_breakers,
    size_t num_breakers);

// Declaration only; the definition is not in this header.
// Unlike llama_sampler_init_dry_impl, this takes no vocabulary and receives
// the sequence breakers as token sequences rather than C strings.
// presumably intended for tests, per the name - confirm
struct llama_sampler * llama_sampler_init_dry_testing(
    int32_t context_size,
    float dry_multiplier,
    float dry_base,
    int32_t dry_allowed_length,
    int32_t dry_penalty_last_n,
    const std::vector<std::vector<llama_token>> & seq_breakers);