gbnf-validator.cpp 3.52 KB
Newer Older
xuxzh1's avatar
init  
xuxzh1 committed
1
#include "unicode.h"
xuxzh1's avatar
update  
xuxzh1 committed
2
#include "llama-grammar.h"
xuxzh1's avatar
init  
xuxzh1 committed
3
4
5
6
7
8
9
10

#include <cstdio>
#include <cstdlib>
#include <sstream>
#include <fstream>
#include <string>
#include <vector>

xuxzh1's avatar
update  
xuxzh1 committed
11
12
static bool llama_grammar_validate(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) {
    const auto cpts = unicode_cpts_from_utf8(input_str);
xuxzh1's avatar
init  
xuxzh1 committed
13
14

    const llama_grammar_rules  & rules      = llama_grammar_get_rules (grammar);
xuxzh1's avatar
update  
xuxzh1 committed
15
          llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
xuxzh1's avatar
init  
xuxzh1 committed
16
17

    size_t pos = 0;
xuxzh1's avatar
update  
xuxzh1 committed
18
19
    for (const auto & cpt : cpts) {
        const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
xuxzh1's avatar
init  
xuxzh1 committed
20

xuxzh1's avatar
update  
xuxzh1 committed
21
        llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
xuxzh1's avatar
init  
xuxzh1 committed
22

xuxzh1's avatar
update  
xuxzh1 committed
23
        if (stacks_cur.empty()) {
xuxzh1's avatar
init  
xuxzh1 committed
24
            error_pos = pos;
xuxzh1's avatar
update  
xuxzh1 committed
25
26
            error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'";
            stacks_cur = stacks_prev;
xuxzh1's avatar
init  
xuxzh1 committed
27
28
29
30
31
            return false;
        }
        ++pos;
    }

xuxzh1's avatar
update  
xuxzh1 committed
32
    for (const auto & stack : stacks_cur) {
xuxzh1's avatar
init  
xuxzh1 committed
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
        if (stack.empty()) {
            return true;
        }
    }

    error_pos = pos;
    error_msg = "Unexpected end of input";
    return false;
}

// Returns the byte length of the UTF-8 sequence starting at `offset` in `str`.
// Precondition: offset < str.size(). A malformed or truncated sequence counts as a single
// byte, so callers walking the string always make progress.
static size_t utf8_seq_len(const std::string & str, size_t offset) {
    const unsigned char lead = static_cast<unsigned char>(str[offset]);
    size_t len = 1;
    if      ((lead & 0xE0) == 0xC0) { len = 2; }
    else if ((lead & 0xF0) == 0xE0) { len = 3; }
    else if ((lead & 0xF8) == 0xF0) { len = 4; }

    if (offset + len > str.size()) {
        return 1;
    }
    for (size_t i = 1; i < len; ++i) {
        // Continuation bytes must look like 10xxxxxx.
        if ((static_cast<unsigned char>(str[offset + i]) & 0xC0) != 0x80) {
            return 1;
        }
    }
    return len;
}

// Converts a code-point index into a byte offset within the UTF-8 string `str`.
// An index past the last code point clamps to str.size().
// NOTE(review): invalid bytes are counted as one code point each here; confirm this matches
// how unicode_cpts_from_utf8 splits malformed input.
static size_t utf8_cpt_offset_to_byte_offset(const std::string & str, size_t cpt_offset) {
    size_t byte_offset = 0;
    for (size_t i = 0; i < cpt_offset && byte_offset < str.size(); ++i) {
        byte_offset += utf8_seq_len(str, byte_offset);
    }
    return byte_offset;
}

// Prints a validation failure report to stdout.
//
// The input is echoed as follows:
//   - everything before the failure point in the default colour;
//   - the offending character in bold red;
//   - the rest of the input in red.
//
// `error_pos` is a code-point index (as reported by llama_grammar_validate). It is converted
// to a byte offset so that multi-byte UTF-8 characters are neither split nor mis-positioned.
static void print_error_message(const std::string & input_str, size_t error_pos, const std::string & error_msg) {
    fprintf(stdout, "Input string is invalid according to the grammar.\n");
    fprintf(stdout, "Error: %s at position %zu\n", error_msg.c_str(), error_pos);
    fprintf(stdout, "\n");
    fprintf(stdout, "Input string:\n");

    const size_t byte_pos = utf8_cpt_offset_to_byte_offset(input_str, error_pos);

    // fwrite (not %s) so embedded NUL bytes in the input do not truncate the echo.
    fwrite(input_str.data(), 1, byte_pos, stdout);
    if (byte_pos < input_str.size()) {
        const size_t bad_len = utf8_seq_len(input_str, byte_pos);
        fprintf(stdout, "\033[1;31m");
        fwrite(input_str.data() + byte_pos, 1, bad_len, stdout);
        const size_t rest = byte_pos + bad_len;
        if (rest < input_str.size()) {
            fprintf(stdout, "\033[0;31m");
            fwrite(input_str.data() + rest, 1, input_str.size() - rest, stdout);
        }
        fprintf(stdout, "\033[0m\n");
    } else {
        // Error at end of input: nothing to highlight, but still terminate the line.
        fprintf(stdout, "\n");
    }
}

int main(int argc, char** argv) {
    if (argc != 3) {
        fprintf(stdout, "Usage: %s <grammar_filename> <input_filename>\n", argv[0]);
        return 1;
    }

    const std::string grammar_filename = argv[1];
    const std::string input_filename = argv[2];

    // Read the GBNF grammar file
    FILE* grammar_file = fopen(grammar_filename.c_str(), "r");
    if (!grammar_file) {
        fprintf(stdout, "Failed to open grammar file: %s\n", grammar_filename.c_str());
        return 1;
    }

    std::string grammar_str;
    {
        std::ifstream grammar_file(grammar_filename);
        GGML_ASSERT(grammar_file.is_open() && "Failed to open grammar file");
        std::stringstream buffer;
        buffer << grammar_file.rdbuf();
        grammar_str = buffer.str();
    }

xuxzh1's avatar
update  
xuxzh1 committed
83
    llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
xuxzh1's avatar
init  
xuxzh1 committed
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
    if (grammar == nullptr) {
        throw std::runtime_error("Failed to initialize llama_grammar");
    }
    // Read the input file
    std::string input_str;
    {
        std::ifstream input_file(input_filename);
        GGML_ASSERT(input_file.is_open() && "Failed to open input file");
        std::stringstream buffer;
        buffer << input_file.rdbuf();
        input_str = buffer.str();
    }

    // Validate the input string against the grammar
    size_t error_pos;
    std::string error_msg;
xuxzh1's avatar
update  
xuxzh1 committed
100
    bool is_valid = llama_grammar_validate(grammar, input_str, error_pos, error_msg);
xuxzh1's avatar
init  
xuxzh1 committed
101
102
103
104
105
106
107
108

    if (is_valid) {
        fprintf(stdout, "Input string is valid according to the grammar.\n");
    } else {
        print_error_message(input_str, error_pos, error_msg);
    }

    // Clean up
xuxzh1's avatar
update  
xuxzh1 committed
109
    llama_grammar_free_impl(grammar);
xuxzh1's avatar
init  
xuxzh1 committed
110
111
112

    return 0;
}