#ifndef CUBERT_TOKENIZATION_H
#define CUBERT_TOKENIZATION_H

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

namespace cuBERT {

void load_vocab(const char* vocab_file, std::unordered_map<std::string, uint64_t>* vocab);
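
/*
 * Usage sketch (illustrative, not part of the API): loading a BERT-style
 * vocab file -- assumed here to be a vocab.txt with one token per line, whose
 * line index becomes its id -- into the map consumed by the tokenizers below.
 *
 *   std::unordered_map<std::string, uint64_t> vocab;
 *   cuBERT::load_vocab("vocab.txt", &vocab);   // "vocab.txt" is a placeholder path
 *   uint64_t cls_id = vocab.at("[CLS]");       // assumes the standard BERT special tokens
 */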

/**
 * Checks whether `c` is a whitespace character.
 * @param c the character to check
 * @return true if `c` is a whitespace character
 */
bool _is_whitespace(int c);

/**
 * Checks whether `c` is a control character.
 * @param c the character to check
 * @return true if `c` is a control character
 */
bool _is_control(int c);

/**
 * Checks whether `cp` is a punctuation character.
 * @param cp the character to check
 * @return true if `cp` is a punctuation character
 */
bool _is_punctuation(int cp);

/**
 * Runs basic tokenization (punctuation splitting, lower casing, etc.).
 */
class BasicTokenizer
{
    public:
    /**
     * Constructs a BasicTokenizer.
     * @param do_lower_case Whether to lower case the input.
     */
    explicit BasicTokenizer(bool do_lower_case = true) : do_lower_case(do_lower_case) {}

    BasicTokenizer(const BasicTokenizer& other) = delete;

    virtual ~BasicTokenizer() = default;

    /**
     * Tokenizes a piece of text.
     *
     * Internally uses the following helpers:
     *   - to_lower: lower-cases the text (when do_lower_case is set).
     *   - _run_strip_accents: strips accents from a piece of text.
     *   - _clean_text: performs invalid character removal and whitespace cleanup on text.
     *   - _tokenize_chinese_chars: adds whitespace around any CJK character.
     *   - _run_split_on_punc: splits punctuation on a piece of text.
     *   - whitespace_tokenize: runs basic whitespace cleaning and splitting on a piece of text.
     *
     * @param text          the input text
     * @param output_tokens the resulting tokens
     * @param max_length    upper limit on the number of output tokens
     */
    void tokenize(const char* text, std::vector<std::string>* output_tokens, size_t max_length);

    private:
    const bool do_lower_case;

    /**
     * Checks whether `cp` is the code point of a CJK character.
     * @param cp Unicode code point
     * @return true if `cp` is the code point of a CJK character
     */
    inline static bool _is_chinese_char(int cp);
};
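
/*
 * Usage sketch (illustrative): basic tokenization on its own, lower-casing and
 * splitting on whitespace and punctuation. The cap of 128 tokens is an
 * arbitrary example value for max_length.
 *
 *   cuBERT::BasicTokenizer basic(true);
 *   std::vector<std::string> tokens;
 *   basic.tokenize("Hello, World! This is cuBERT.", &tokens, 128);
 *   // expected (assuming default lower casing):
 *   //   "hello" "," "world" "!" "this" "is" "cubert" "."
 */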

/**
 * Runs WordPiece tokenization.
 */
class WordpieceTokenizer
{
    public:
    explicit WordpieceTokenizer(std::unordered_map<std::string, uint64_t>* vocab,
                                std::string unk_token        = "[UNK]",
                                int max_input_chars_per_word = 200)
        : vocab(vocab), unk_token(unk_token), max_input_chars_per_word(max_input_chars_per_word)
    {
    }

    WordpieceTokenizer(const WordpieceTokenizer& other) = delete;

    virtual ~WordpieceTokenizer() = default;

    /**
     * Tokenizes a piece of text into its word pieces.
     *
     * This uses a greedy longest-match-first algorithm to perform tokenization
     * using the given vocabulary.
     *
     * For example:
     *   input = "unaffable"
     *   output = ["un", "##aff", "##able"]
     *
     * @param text A single token or whitespace separated tokens. This should have
     * already been passed through `BasicTokenizer`.
     * @param output_tokens A list of wordpiece tokens.
     */
    void tokenize(const std::string& text, std::vector<std::string>* output_tokens);

    private:
    const std::unordered_map<std::string, uint64_t>* vocab;
    const std::string unk_token;
    const int max_input_chars_per_word;
};
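
/*
 * Usage sketch (illustrative): greedy longest-match-first word-piece splitting
 * against a previously loaded vocabulary, mirroring the "unaffable" example
 * documented above. The vocab path is a placeholder.
 *
 *   std::unordered_map<std::string, uint64_t> vocab;
 *   cuBERT::load_vocab("vocab.txt", &vocab);
 *   cuBERT::WordpieceTokenizer wordpiece(&vocab);
 *   std::vector<std::string> pieces;
 *   wordpiece.tokenize("unaffable", &pieces);
 *   // with a vocabulary containing "un", "##aff", "##able":
 *   //   pieces == {"un", "##aff", "##able"}
 */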

/**
 * Runs end-to-end tokenization.
 */
class FullTokenizer
{
    public:
    FullTokenizer(const char* vocab_file, bool do_lower_case = true)
    {
        vocab = new std::unordered_map<std::string, uint64_t>();
        load_vocab(vocab_file, vocab);
        basic_tokenizer     = new BasicTokenizer(do_lower_case);
        wordpiece_tokenizer = new WordpieceTokenizer(vocab);
    }

    ~FullTokenizer()
    {
        // Delete in reverse order of construction. delete on a null pointer is
        // a no-op, so no null checks are needed; nulling the pointers afterwards
        // guards against accidental reuse.
        delete wordpiece_tokenizer;
        wordpiece_tokenizer = nullptr;

        delete basic_tokenizer;
        basic_tokenizer = nullptr;

        delete vocab;
        vocab = nullptr;
    }

    void tokenize(const char* text, std::vector<std::string>* output_tokens, size_t max_length);

    inline uint64_t convert_token_to_id(const std::string& token)
    {
        auto item = vocab->find(token);
        if(item == vocab->end())
        {
            std::cerr << "vocab missing key: " << token << std::endl;
            return 0;
        }
        else
        {
            return item->second;
        }
    }

    void convert_tokens_to_ids(const std::vector<std::string>& tokens, uint64_t* ids);

    private:
    std::unordered_map<std::string, uint64_t>* vocab;
    BasicTokenizer* basic_tokenizer;
    WordpieceTokenizer* wordpiece_tokenizer;
};
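
/*
 * Usage sketch (illustrative): end-to-end tokenization plus id conversion.
 * "vocab.txt" and the length of 128 are placeholder values.
 *
 *   cuBERT::FullTokenizer tokenizer("vocab.txt", true);
 *   std::vector<std::string> tokens;
 *   tokenizer.tokenize("The quick brown fox", &tokens, 128);
 *   std::vector<uint64_t> ids(tokens.size());
 *   tokenizer.convert_tokens_to_ids(tokens, ids.data());
 */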

} // namespace cuBERT

#endif // CUBERT_TOKENIZATION_H