test-tokenizer-1-bpe.cpp 4.68 KB
Newer Older
xuxzh1's avatar
init  
xuxzh1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#include "llama.h"
#include "common.h"
#include "unicode.h"
#include "console.h"

#include <cassert>
#include <codecvt>
#include <cstdio>
#include <cstring>
#include <locale>
#include <string>
#include <thread>
#include <vector>
#include <atomic>

int main(int argc, char **argv) {
    if (argc < 2 || argc > 3) {
        fprintf(stderr, "Usage: %s <vocab-file> [--ignore-merges]\n", argv[0]);
        return 1;
    }

    const std::string fname = argv[1];
    bool ignore_merges = false;
    if (argc == 3) {
        if (std::strcmp(argv[2], "--ignore-merges") != 0) {
            fprintf(stderr, "Usage: %s <vocab-file> [--ignore-merges]\n", argv[0]);
            return 1;
        }
        ignore_merges = true;
    }

    fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());

    if (ignore_merges) {
        fprintf(stderr, "%s : ignoring merges for tokens inside vocab\n", __func__);
    }

    llama_model * model;
    llama_context * ctx;

    llama_backend_init();

    // load the vocab
    {
        auto mparams = llama_model_default_params();

        mparams.vocab_only = true;

        model = llama_load_model_from_file(fname.c_str(), mparams);

        if (model == NULL) {
            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
            return 1;
        }

        auto cparams = llama_context_default_params();

        ctx = llama_new_context_with_model(model, cparams);

        if (ctx == NULL) {
            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
            llama_free_model(model);
            return 1;
        }
    }

    //GGML_ASSERT(llama_vocab_type(model) == LLAMA_VOCAB_TYPE_BPE);
    if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_BPE) {
        return 99;
    }

#ifdef _WIN32
    // We need this for unicode console support
    console::init(false, false);
    atexit([]() { console::cleanup(); });
#endif

    const int n_vocab = llama_n_vocab(model);

    for (int i = 0; i < n_vocab; ++i) {
xuxzh1's avatar
update  
xuxzh1 committed
81
        std::string str = common_detokenize(ctx, std::vector<int>(1, i));
xuxzh1's avatar
init  
xuxzh1 committed
82
83
        try {
            auto cps = unicode_cpts_from_utf8(str);
xuxzh1's avatar
update  
xuxzh1 committed
84
            std::vector<llama_token> tokens = common_tokenize(ctx, str, false, true);
xuxzh1's avatar
init  
xuxzh1 committed
85
86
87
88
89
90
91
92
93
94
95
96
            if (ignore_merges && tokens.size() > 1) {
                fprintf(stderr,
                        "%s : error: token %d detokenizes to '%s'(%zu) but "
                        "tokenization of this to multiple tokens: [",
                        __func__, i, str.c_str(), str.length());
                fprintf(stderr, "%d", tokens[0]);
                for (size_t i = 1; i < tokens.size(); i++) {
                    fprintf(stderr, ", %d", tokens[i]);
                }
                fprintf(stderr, "]\n");
                return 2;
            }
xuxzh1's avatar
update  
xuxzh1 committed
97
            std::string check = common_detokenize(ctx, tokens);
xuxzh1's avatar
init  
xuxzh1 committed
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
            if (check != str) {
                fprintf(stderr, "%s : error: token %d detokenizes to '%s'(%zu) but tokenization of this detokenizes to '%s'(%zu)\n",
                    __func__, i, str.c_str(), str.length(), check.c_str(), check.length());
                return 2;
            }
        }
        catch (const std::invalid_argument &) {
            //fprintf(stderr, "%s : info: utf8 conversion %d '%s'\n", __func__, i, str.c_str());
        }
    }

    // unicode
    {
        const int nthread = std::thread::hardware_concurrency();

        std::vector<std::thread> threads(nthread);

        std::atomic_int errcode = {};

        for (int i = 0; i < nthread; ++i) {
            threads[i] = std::thread([i, nthread, ctx, &errcode]() {
                for (uint32_t cp = i; !errcode && cp < 0x00110000; cp += nthread) {
                    if ((0x0000D800 <= cp && cp <= 0x0000DFFF) ||  // surrogates \p{Cs}
                        (0x00040000 <= cp && cp <= 0x000E0000)) {  // undefined  \p{Cn}
                        continue;
                    }

                    std::string str = unicode_cpt_to_utf8(cp);
xuxzh1's avatar
update  
xuxzh1 committed
126
127
                    std::vector<llama_token> tokens = common_tokenize(ctx, str, false);
                    std::string check = common_detokenize(ctx, tokens);
xuxzh1's avatar
init  
xuxzh1 committed
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
                    if (cp != 9601 && str != check) {
                        fprintf(stderr, "error: codepoint 0x%x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
                                cp, check.c_str(), check.length(), str.c_str(), str.length());
                        errcode = 3;
                    }
                }
            });
        }

        for (auto & t : threads) {
            t.join();
        }

        if (errcode) {
            return errcode;
        }
    }

    llama_free_model(model);
    llama_free(ctx);

    llama_backend_free();

    return 0;
}