llama-grammar.h 4.54 KB
Newer Older
xuxzh1's avatar
init  
xuxzh1 committed
1
2
3
4
#pragma once

#include "llama-impl.h"

xuxzh1's avatar
update  
xuxzh1 committed
5
6
#include <map>

xuxzh1's avatar
init  
xuxzh1 committed
7
struct llama_vocab;
xuxzh1's avatar
update  
xuxzh1 committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107

// grammar element type
enum llama_gretype {
    // end of rule definition
    LLAMA_GRETYPE_END            = 0,

    // start of alternate definition for rule
    LLAMA_GRETYPE_ALT            = 1,

    // non-terminal element: reference to rule
    LLAMA_GRETYPE_RULE_REF       = 2,

    // terminal element: character (code point)
    LLAMA_GRETYPE_CHAR           = 3,

    // inverse char(s) ([^a], [^a-b] [^abc])
    LLAMA_GRETYPE_CHAR_NOT       = 4,

    // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
    // be an inclusive range ([a-z])
    LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,

    // modifies a preceding LLAMA_GRETYPE_CHAR or
    // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
    LLAMA_GRETYPE_CHAR_ALT       = 6,

    // any character (.)
    LLAMA_GRETYPE_CHAR_ANY       = 7,
};

typedef struct llama_grammar_element {
    enum llama_gretype type;
    uint32_t           value; // Unicode code point or rule ID
} llama_grammar_element;

struct llama_partial_utf8 {
    uint32_t value;    // bit value so far (unshifted)
    int      n_remain; // num bytes remaining; -1 indicates invalid sequence
};

struct llama_grammar_candidate {
    size_t               index;
    const uint32_t     * code_points;
    llama_partial_utf8   partial_utf8;
};

using llama_grammar_rule  = std::vector<      llama_grammar_element>;
using llama_grammar_stack = std::vector<const llama_grammar_element *>;

using llama_grammar_rules      = std::vector<llama_grammar_rule>;
using llama_grammar_stacks     = std::vector<llama_grammar_stack>;
using llama_grammar_candidates = std::vector<llama_grammar_candidate>;

const llama_grammar_rules  & llama_grammar_get_rules (const struct llama_grammar * grammar);
      llama_grammar_stacks & llama_grammar_get_stacks(      struct llama_grammar * grammar);

// takes a set of possible pushdown stacks on a grammar, which are required to
// be positioned at a character range (see `llama_grammar_advance_stack`), and
// produces the N possible stacks if the given char is accepted at those
// positions
void llama_grammar_accept(
        const llama_grammar_rules  & rules,
        const llama_grammar_stacks & stacks,
                          uint32_t   chr,
              llama_grammar_stacks & stacks_new);

std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
        const llama_grammar_rules      & rules,
        const llama_grammar_stack      & stack,
        const llama_grammar_candidates & candidates);

struct llama_grammar_parser {
    std::map<std::string, uint32_t> symbol_ids;

    llama_grammar_rules rules;

    llama_grammar_stack c_rules() const;

    uint32_t get_symbol_id(const char * src, size_t len);
    uint32_t generate_symbol_id(const std::string & base_name);

    void add_rule(uint32_t rule_id, const llama_grammar_rule & rule);

    const char * parse_alternates(
            const char        * src,
            const std::string & rule_name,
            uint32_t            rule_id,
            bool                is_nested);

    const char * parse_sequence(
            const char         * src,
            const std::string  & rule_name,
            llama_grammar_rule & rule,
            bool               is_nested);

    const char * parse_rule(const char * src);

    bool parse(const char * src);
    void print(FILE * file);
};
xuxzh1's avatar
init  
xuxzh1 committed
108
109

struct llama_grammar {
xuxzh1's avatar
update  
xuxzh1 committed
110
111
112
113
    // note: allow null vocab for testing (not great)
    const llama_vocab * vocab;

    const llama_grammar_rules  rules;  // TODO: shared ptr
xuxzh1's avatar
init  
xuxzh1 committed
114
115
116
117
118
119
120
121
122
123
          llama_grammar_stacks stacks;

    // buffer for partially generated UTF-8 sequence from accepted tokens
    llama_partial_utf8 partial_utf8;
};

//
// internal API
//

xuxzh1's avatar
update  
xuxzh1 committed
124
// note: needed for tests (not great)
xuxzh1's avatar
init  
xuxzh1 committed
125
struct llama_grammar * llama_grammar_init_impl(
xuxzh1's avatar
update  
xuxzh1 committed
126
127
128
129
130
131
        const struct llama_vocab * vocab,
        const llama_grammar_element ** rules,
        size_t n_rules,
        size_t start_rule_index);

struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root);
xuxzh1's avatar
init  
xuxzh1 committed
132
133
134

void llama_grammar_free_impl(struct llama_grammar * grammar);

xuxzh1's avatar
update  
xuxzh1 committed
135
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar);
xuxzh1's avatar
init  
xuxzh1 committed
136

xuxzh1's avatar
update  
xuxzh1 committed
137
138
139
140
// TODO: move the API below as member functions of llama_grammar
void llama_grammar_apply_impl(
        const struct llama_grammar & grammar,
            llama_token_data_array * cur_p);
xuxzh1's avatar
init  
xuxzh1 committed
141

xuxzh1's avatar
update  
xuxzh1 committed
142
143
void llama_grammar_accept_impl(
              struct llama_grammar & grammar,
xuxzh1's avatar
init  
xuxzh1 committed
144
                       llama_token   token);