llama-grammar.h 6.52 KB
Newer Older
1
2
#pragma once

3
#include "llama.h"
4

5
#include <map>
6
#include <regex>
7
8
#include <string>
#include <vector>
9
#include <set>
10

11
struct llama_vocab;
12
13
14
15
16
17
18
19
20
21
// Minimal vocabulary container: a token-id -> piece-text table plus the set of
// token ids flagged as end-of-generation (EOG).
struct ollama_vocab {
    // piece text keyed by token id
    std::map<uint32_t, std::string> token_to_piece_map;
    // token ids considered end-of-generation
    std::set<uint32_t> special_eog_ids;

    // registration
    // NOTE(review): `tokens` and `pieces` look like parallel arrays of length n_tokens - confirm in the implementation
    void add_token_pieces(const uint32_t * tokens, size_t n_tokens, const char ** pieces);
    void set_eog_tokens(const uint32_t * tokens, size_t n_tokens);

    // queries
    const std::string & token_to_piece(uint32_t token) const;
    bool is_eog(uint32_t token) const;
};
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74

// grammar element type
// Tag stored in llama_grammar_element::type; every enumerator has an explicit value.
enum llama_gretype {
    LLAMA_GRETYPE_END = 0,       // end of rule definition
    LLAMA_GRETYPE_ALT = 1,       // start of alternate definition for rule
    LLAMA_GRETYPE_RULE_REF = 2,  // non-terminal element: reference to rule
    LLAMA_GRETYPE_CHAR = 3,      // terminal element: character (code point)
    LLAMA_GRETYPE_CHAR_NOT = 4,  // inverse char(s): [^a], [^a-b], [^abc]

    // turns a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT into an
    // inclusive range, e.g. [a-z]
    LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,

    // adds an alternate char to a preceding LLAMA_GRETYPE_CHAR or
    // LLAMA_GRETYPE_CHAR_RNG_UPPER, e.g. [ab], [a-zA]
    LLAMA_GRETYPE_CHAR_ALT = 6,

    LLAMA_GRETYPE_CHAR_ANY = 7,  // any character (.)
};

// A single grammar element: a type tag and the value that goes with it.
struct llama_grammar_element {
    enum llama_gretype type;
    // Unicode code point or rule ID
    uint32_t value;
};

// State of a UTF-8 sequence that has only been partially consumed.
struct llama_partial_utf8 {
    // bit value so far (unshifted)
    uint32_t value;
    // num bytes remaining; -1 indicates invalid sequence
    int n_remain;
};

// One candidate to be checked against the grammar.
struct llama_grammar_candidate {
    // NOTE(review): presumably the candidate's position in the caller's array - confirm
    size_t index;
    // pointer to the candidate's code points
    // NOTE(review): ownership and termination convention are not visible here - confirm
    const uint32_t * code_points;
    // partial UTF-8 state carried with this candidate
    llama_partial_utf8 partial_utf8;
};

// rules: a rule is a sequence of elements, and a grammar holds a list of rules
using llama_grammar_rule       = std::vector<llama_grammar_element>;
using llama_grammar_rules      = std::vector<llama_grammar_rule>;

// stacks: a stack is a sequence of pointers to (const) elements
using llama_grammar_stack      = std::vector<const llama_grammar_element *>;
using llama_grammar_stacks     = std::vector<llama_grammar_stack>;

// candidate list
using llama_grammar_candidates = std::vector<llama_grammar_candidate>;

75
// TODO: remove, needed for tests atm
76
77
78
79
80
81
82
const llama_grammar_rules  & llama_grammar_get_rules (const struct llama_grammar * grammar);
      llama_grammar_stacks & llama_grammar_get_stacks(      struct llama_grammar * grammar);

// takes a set of possible pushdown stacks on a grammar, which are required to
// be positioned at a character range (see `llama_grammar_advance_stack`), and
// produces the N possible stacks if the given char is accepted at those
// positions
//
// grammar: grammar whose stacks are advanced
// chr:     the character to accept (NOTE(review): presumably a Unicode code point - confirm)
void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr);
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118

// Evaluates `candidates` against a single `stack` under `rules` and returns a
// candidate list.
// NOTE(review): by its name the returned list holds the rejected candidates - confirm in the implementation
llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
        const llama_grammar_rules & rules,
        const llama_grammar_stack & stack,
        const llama_grammar_candidates & candidates);

// Parser state for a textual grammar: a symbol-name -> id table and the rules
// collected so far.
struct llama_grammar_parser {
    // ids keyed by symbol name
    std::map<std::string, uint32_t> symbol_ids;

    // NOTE(review): presumably indexed by rule id - confirm in the implementation
    llama_grammar_rules rules;

    llama_grammar_stack c_rules() const;

    // symbol id helpers
    uint32_t get_symbol_id(const char * src, size_t len);
    uint32_t generate_symbol_id(const std::string & base_name);

    void add_rule(uint32_t rule_id, const llama_grammar_rule & rule);

    // parsing steps; each takes a position in the source text and returns a
    // `const char *` (NOTE(review): presumably the position after the consumed
    // input - confirm)
    const char * parse_alternates(const char * src, const std::string & rule_name, uint32_t rule_id, bool is_nested);
    const char * parse_sequence(const char * src, const std::string & rule_name, llama_grammar_rule & rule, bool is_nested);
    const char * parse_rule(const char * src);

    // entry points
    bool parse(const char * src);
    void print(FILE * file);
};
119

120
121
122
123
124
// A trigger pattern held both as text and as a std::regex object.
// NOTE(review): presumably `regex` is compiled from `pattern` - confirm where instances are built
struct llama_grammar_trigger_pattern {
    // pattern source text
    std::string pattern;
    // regex object used for matching
    std::regex regex;
};

125
struct llama_grammar {
126
127
    // note: allow null vocab for testing (not great)
    const llama_vocab * vocab;
128
    const ollama_vocab * o_vocab;
129
130

    const llama_grammar_rules  rules;  // TODO: shared ptr
131
132
133
134
          llama_grammar_stacks stacks;

    // buffer for partially generated UTF-8 sequence from accepted tokens
    llama_partial_utf8 partial_utf8;
135
136
137
138
139
140
141
142

    // lazy grammars wait for trigger words or tokens before constraining the sampling.
    // we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens.
    // (useful e.g. for tool_choice=required)
    bool                     lazy             = false;
    bool                     awaiting_trigger = false; // Initialized to true for lazy grammars only
    std::string              trigger_buffer;           // Output buffered by lazy grammar. Will be cleared once trigger is found.
    std::vector<llama_token> trigger_tokens;           // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
143
144
145
146
    std::vector<llama_grammar_trigger_pattern>
                             trigger_patterns;         // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated
                                                       // string, and the grammar will be given the string from the first match group onwards.

147
148
149
150
151
152
};

//
// internal API
//

153
// note: needed for tests (not great)
154
struct llama_grammar * llama_grammar_init_impl(
155
        const struct llama_vocab * vocab,
156
        const struct ollama_vocab * ollama_vocab,
157
158
159
160
        const llama_grammar_element ** rules,
        size_t n_rules,
        size_t start_rule_index);

161
162
struct llama_grammar * llama_grammar_init_impl(
        const struct llama_vocab * vocab,
163
        const struct ollama_vocab * ollama_vocab,
164
165
166
                      const char * grammar_str,
                      const char * grammar_root,
                              bool lazy,
167
168
                     const char ** trigger_patterns,
                            size_t num_trigger_patterns,
169
170
               const llama_token * trigger_tokens,
                            size_t num_trigger_tokens);
171
172
173

// Frees a grammar.
// NOTE(review): presumably the counterpart of llama_grammar_init_impl / llama_grammar_clone_impl;
// whether a null `grammar` is accepted is not visible here - confirm.
void llama_grammar_free_impl(struct llama_grammar * grammar);

174
// Clones `grammar` and returns a pointer to the new grammar.
// NOTE(review): presumably the result must be released with llama_grammar_free_impl - confirm.
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar);
175

176
177
178
179
// TODO: move the API below as member functions of llama_grammar

// Applies `grammar` to the token candidates in `cur_p`.
// The grammar is taken by const reference; `cur_p` is passed as a non-const pointer.
// NOTE(review): presumably tokens the grammar cannot accept are masked out in `cur_p` - confirm.
void llama_grammar_apply_impl(
        const struct llama_grammar & grammar,
            llama_token_data_array * cur_p);
180

181
182
void llama_grammar_accept_impl(
              struct llama_grammar & grammar,
183
                       llama_token   token);
184
185
186
187

// Accepts the text `piece` into `grammar`; the grammar is taken by non-const reference.
// NOTE(review): presumably the string counterpart of llama_grammar_accept_impl - confirm.
void llama_grammar_accept_str(
              struct llama_grammar & grammar,
                 const std::string & piece);