// TODO: this is a temporary wrapper that allows C++ code to be called from CGo
#include "sampling.h"
#include "sampling_ext.h"
#include "json-schema-to-grammar.h"
#include "llama.h"
#include "llama-model.h"
#include "llama-model-loader.h"
#include "llama-grammar.h"

#include <algorithm>
#include <climits>
#include <cstring>
#include <memory>
#include <string>
struct common_sampler *common_sampler_cinit(const struct llama_model *model, struct common_sampler_cparams *params) {
    try {
        common_params_sampling sparams;
Jesse Gross's avatar
Jesse Gross committed
13
14
15
16
17
18
19
20
21
22
23
        sparams.top_k = params->top_k;
        sparams.top_p = params->top_p;
        sparams.min_p = params->min_p;
        sparams.typ_p = params->typical_p;
        sparams.temp = params->temp;
        sparams.penalty_last_n = params->penalty_last_n;
        sparams.penalty_repeat = params->penalty_repeat;
        sparams.penalty_freq = params->penalty_freq;
        sparams.penalty_present = params->penalty_present;
        sparams.seed = params->seed;
        sparams.grammar = params->grammar;
24
25
26
27
        sparams.xtc_probability = 0.0;
        sparams.xtc_threshold = 0.5;
        return common_sampler_init(model, sparams);
    } catch (const std::exception &err) {
Jesse Gross's avatar
Jesse Gross committed
28
29
        return nullptr;
    }
30
31
}

// Releases a sampler previously returned by common_sampler_cinit by handing
// the pointer to common_sampler_free.
void common_sampler_cfree(struct common_sampler *sampler)
{
    common_sampler_free(sampler);
}

// C-callable wrapper: forwards `sampler` unchanged to common_sampler_reset.
void common_sampler_creset(struct common_sampler *sampler)
{
    common_sampler_reset(sampler);
}

// C-callable wrapper: passes the sampled token `id` and the `apply_grammar`
// flag straight through to common_sampler_accept.
void common_sampler_caccept(struct common_sampler *sampler, llama_token id, bool apply_grammar)
{
    common_sampler_accept(sampler, id, apply_grammar);
}

// C-callable wrapper around common_sampler_sample.
// `idx` is forwarded unchanged, and the token chosen by the sampler is
// returned as-is.
llama_token common_sampler_csample(struct common_sampler *sampler, struct llama_context *ctx, int idx)
{
    const llama_token chosen = common_sampler_sample(sampler, ctx, idx);
    return chosen;
}

int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len)
{
    try
    {
52
        nlohmann::ordered_json schema = nlohmann::ordered_json::parse(json_schema);
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
        std::string grammar_str = json_schema_to_grammar(schema);
        size_t len = grammar_str.length();
        if (len >= max_len)
        {
            len = max_len - 1;
        }
        strncpy(grammar, grammar_str.c_str(), len);
        return len;
    }
    catch (const std::exception &e)
    {
        strncpy(grammar, "", max_len - 1);
        return 0;
    }
}

struct llama_vocab * llama_load_vocab_from_file(const char * fname) {
    llama_vocab * vocab = new llama_vocab();
    try {
        const auto kv = LLM_KV(LLM_ARCH_UNKNOWN);
        std::vector<std::string> splits = {};
74
        llama_model_loader ml(std::string(fname), splits, false, false, nullptr, nullptr);
75
76
77
78
79
80
81
82
83
84
85
86
        vocab->load(ml, kv);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
        return nullptr;
    }

    return vocab;
}

// Destroys a vocab obtained from llama_load_vocab_from_file.
// A null pointer is accepted and ignored.
void llama_free_vocab(struct llama_vocab * vocab) {
    if (vocab != nullptr) {
        delete vocab;
    }
}
struct llama_grammar *grammar_init(char* grammar, uint32_t* tokens, size_t n_tokens, const char** pieces, uint32_t* eog_tokens, size_t n_eog_tokens) {
    try {
        if (grammar == nullptr) {
            LLAMA_LOG_ERROR("%s: null grammar input\n", __func__);
            return nullptr;
        }

        ollama_vocab *vocab = new ollama_vocab();
        vocab->set_eog_tokens(eog_tokens, n_eog_tokens);
        vocab->add_token_pieces(tokens, n_tokens, pieces);
        
        struct llama_grammar *g = llama_grammar_init_impl(nullptr, vocab, grammar, "root", false, nullptr, 0, nullptr, 0);
        if (g == nullptr) {
            LLAMA_LOG_ERROR("%s: failed to initialize grammar\n", __func__);
            delete vocab;
            return nullptr;
        }
        return g;

    } catch (const std::exception& e) {
        LLAMA_LOG_ERROR("%s: exception during initialization: %s\n", __func__, e.what());
        return nullptr;
    }
}

// Frees a grammar together with the object held in its `vocab` field.
// A null grammar is ignored.
// NOTE(review): this assumes `g->vocab` is owned by the grammar, i.e. it was
// allocated in grammar_init. Confirm against the llama_grammar definition.
void grammar_free(struct llama_grammar *g) {
    if (g == nullptr) {
        return;
    }
    // `delete` on a null pointer is a no-op, so no separate check is needed.
    delete g->vocab;
    llama_grammar_free_impl(g);
}

// Applies the grammar constraints in `g` to the candidate array `tokens` via
// llama_grammar_apply_impl. Logs an error and does nothing if either pointer
// is null.
void grammar_apply(struct llama_grammar *g, struct llama_token_data_array *tokens) {
    if (g != nullptr && tokens != nullptr) {
        llama_grammar_apply_impl(*g, tokens);
    } else {
        LLAMA_LOG_ERROR("%s: null grammar or tokens input\n", __func__);
    }
}


void grammar_accept(struct llama_grammar *g, llama_token id) {
    llama_grammar_accept_impl(*g, id);
}