# coding=utf-8
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import unittest

from transformers import BertTokenizerFast
from transformers.models.bert.tokenization_bert import (
    VOCAB_FILES_NAMES,
    BasicTokenizer,
    BertTokenizer,
    WordpieceTokenizer,
    _is_control,
    _is_punctuation,
    _is_whitespace,
)
from transformers.testing_utils import require_tokenizers, slow

from ...test_tokenization_common import TokenizerTesterMixin, filter_non_english


@require_tokenizers
class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    tokenizer_class = BertTokenizer
    rust_tokenizer_class = BertTokenizerFast
    test_rust_tokenizer = True
    space_between_special_tokens = True
    from_pretrained_filter = filter_non_english

    def setUp(self):
        super().setUp()

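        # Tiny toy WordPiece vocabulary; it is written to a temp file and each token's list index is its id.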
        vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "[PAD]",
            "[MASK]",
            "want",
            "##want",
            "##ed",
            "wa",
            "un",
            "runn",
            "##ing",
            ",",
            "low",
            "lowest",
        ]
        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

    def get_input_output_texts(self, tokenizer):
        input_text = "UNwant\u00E9d,running"
        output_text = "unwanted, running"
        return input_text, output_text

    def test_full_tokenizer(self):
        tokenizer = self.tokenizer_class(self.vocab_file)

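        # The default BertTokenizer lower-cases and strips accents, so "UNwant\u00e9d" splits into "un ##want ##ed".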
        tokens = tokenizer.tokenize("UNwant\u00E9d,running")
        self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [9, 6, 7, 12, 10, 11])

    def test_rust_and_python_full_tokenizers(self):
        if not self.test_rust_tokenizer:
            self.skipTest("test_rust_tokenizer is set to False")

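        # The slow (Python) and fast (Rust) tokenizers should agree on both tokens and ids.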
        tokenizer = self.get_tokenizer()
        rust_tokenizer = self.get_rust_tokenizer()

        sequence = "UNwant\u00E9d,running"

        tokens = tokenizer.tokenize(sequence)
        rust_tokens = rust_tokenizer.tokenize(sequence)
        self.assertListEqual(tokens, rust_tokens)

        ids = tokenizer.encode(sequence, add_special_tokens=False)
        rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
        self.assertListEqual(ids, rust_ids)

        rust_tokenizer = self.get_rust_tokenizer()
        ids = tokenizer.encode(sequence)
        rust_ids = rust_tokenizer.encode(sequence)
        self.assertListEqual(ids, rust_ids)

        # With lower casing
        tokenizer = self.get_tokenizer(do_lower_case=True)
        rust_tokenizer = self.get_rust_tokenizer(do_lower_case=True)

        sequence = "UNwant\u00E9d,running"

        tokens = tokenizer.tokenize(sequence)
        rust_tokens = rust_tokenizer.tokenize(sequence)
        self.assertListEqual(tokens, rust_tokens)

        ids = tokenizer.encode(sequence, add_special_tokens=False)
        rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
        self.assertListEqual(ids, rust_ids)

        rust_tokenizer = self.get_rust_tokenizer()
        ids = tokenizer.encode(sequence)
        rust_ids = rust_tokenizer.encode(sequence)
        self.assertListEqual(ids, rust_ids)

    def test_chinese(self):
        tokenizer = BasicTokenizer()

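        # Each CJK character is split out as its own token by the basic tokenizer.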
        self.assertListEqual(tokenizer.tokenize("ah\u535A\u63A8zz"), ["ah", "\u535A", "\u63A8", "zz"])

    def test_basic_tokenizer_lower(self):
        tokenizer = BasicTokenizer(do_lower_case=True)

        self.assertListEqual(
            tokenizer.tokenize(" \tHeLLo!how  \n Are yoU?  "), ["hello", "!", "how", "are", "you", "?"]
        )
        self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])

    def test_basic_tokenizer_lower_strip_accents_false(self):
        tokenizer = BasicTokenizer(do_lower_case=True, strip_accents=False)

        self.assertListEqual(
            tokenizer.tokenize(" \tHäLLo!how  \n Are yoU?  "), ["hällo", "!", "how", "are", "you", "?"]
        )
        self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["h\u00E9llo"])

    def test_basic_tokenizer_lower_strip_accents_true(self):
        tokenizer = BasicTokenizer(do_lower_case=True, strip_accents=True)

        self.assertListEqual(
            tokenizer.tokenize(" \tHäLLo!how  \n Are yoU?  "), ["hallo", "!", "how", "are", "you", "?"]
        )
        self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])

    def test_basic_tokenizer_lower_strip_accents_default(self):
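        # strip_accents is left unset, so lower-casing also strips accents by default.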
        tokenizer = BasicTokenizer(do_lower_case=True)

        self.assertListEqual(
            tokenizer.tokenize(" \tHäLLo!how  \n Are yoU?  "), ["hallo", "!", "how", "are", "you", "?"]
        )
        self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])

    def test_basic_tokenizer_no_lower(self):
        tokenizer = BasicTokenizer(do_lower_case=False)

        self.assertListEqual(
            tokenizer.tokenize(" \tHeLLo!how  \n Are yoU?  "), ["HeLLo", "!", "how", "Are", "yoU", "?"]
        )

    def test_basic_tokenizer_no_lower_strip_accents_false(self):
        tokenizer = BasicTokenizer(do_lower_case=False, strip_accents=False)

        self.assertListEqual(
            tokenizer.tokenize(" \tHäLLo!how  \n Are yoU?  "), ["HäLLo", "!", "how", "Are", "yoU", "?"]
        )

    def test_basic_tokenizer_no_lower_strip_accents_true(self):
        tokenizer = BasicTokenizer(do_lower_case=False, strip_accents=True)

        self.assertListEqual(
            tokenizer.tokenize(" \tHäLLo!how  \n Are yoU?  "), ["HaLLo", "!", "how", "Are", "yoU", "?"]
        )

    def test_basic_tokenizer_respects_never_split_tokens(self):
        tokenizer = BasicTokenizer(do_lower_case=False, never_split=["[UNK]"])

        self.assertListEqual(
            tokenizer.tokenize(" \tHeLLo!how  \n Are yoU? [UNK]"), ["HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]"]
        )

    def test_basic_tokenizer_splits_on_punctuation(self):
        tokenizer = BasicTokenizer()
        text = "a\n'll !!to?'d of, can't."
        expected = ["a", "'", "ll", "!", "!", "to", "?", "'", "d", "of", ",", "can", "'", "t", "."]
        self.assertListEqual(tokenizer.tokenize(text), expected)

    def test_wordpiece_tokenizer(self):
        vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"]

        vocab = {}
        for i, token in enumerate(vocab_tokens):
            vocab[token] = i
        tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")

        self.assertListEqual(tokenizer.tokenize(""), [])

        self.assertListEqual(tokenizer.tokenize("unwanted running"), ["un", "##want", "##ed", "runn", "##ing"])

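        # A word containing an out-of-vocabulary piece ("X") is mapped entirely to the unknown token.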
        self.assertListEqual(tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])

    def test_is_whitespace(self):
        self.assertTrue(_is_whitespace(" "))
        self.assertTrue(_is_whitespace("\t"))
        self.assertTrue(_is_whitespace("\r"))
        self.assertTrue(_is_whitespace("\n"))
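        # "\u00A0" is a non-breaking space (Unicode category Zs), which also counts as whitespace.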
        self.assertTrue(_is_whitespace("\u00A0"))

        self.assertFalse(_is_whitespace("A"))
        self.assertFalse(_is_whitespace("-"))

    def test_is_control(self):
        self.assertTrue(_is_control("\u0005"))

        self.assertFalse(_is_control("A"))
        self.assertFalse(_is_control(" "))
        self.assertFalse(_is_control("\t"))
        self.assertFalse(_is_control("\r"))

    def test_is_punctuation(self):
        self.assertTrue(_is_punctuation("-"))
        self.assertTrue(_is_punctuation("$"))
        self.assertTrue(_is_punctuation("`"))
        self.assertTrue(_is_punctuation("."))

        self.assertFalse(_is_punctuation("A"))
        self.assertFalse(_is_punctuation(" "))

    def test_clean_text(self):
        tokenizer = self.get_tokenizer()
        rust_tokenizer = self.get_rust_tokenizer()

        # Example taken from the issue https://github.com/huggingface/tokenizers/issues/340
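        # "\xad" is a soft hyphen; it is treated as a control character and removed during cleanup, so it yields no tokens.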
        self.assertListEqual([tokenizer.tokenize(t) for t in ["Test", "\xad", "test"]], [["[UNK]"], [], ["[UNK]"]])

        self.assertListEqual(
            [rust_tokenizer.tokenize(t) for t in ["Test", "\xad", "test"]], [["[UNK]"], [], ["[UNK]"]]
        )

    @slow
    def test_sequence_builders(self):
        tokenizer = self.tokenizer_class.from_pretrained("google-bert/bert-base-uncased")

        text = tokenizer.encode("sequence builders", add_special_tokens=False)
        text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)

        encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
        encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)

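        # 101 and 102 are the ids of [CLS] and [SEP] in the bert-base-uncased vocabulary.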
        self.assertEqual(encoded_sentence, [101] + text + [102])
        self.assertEqual(encoded_pair, [101] + text + [102] + text_2 + [102])

    def test_offsets_with_special_characters(self):
        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
                tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)

                sentence = f"A, naïve {tokenizer_r.mask_token} AllenNLP sentence."
                tokens = tokenizer_r.encode_plus(
                    sentence,
                    return_attention_mask=False,
                    return_token_type_ids=False,
                    return_offsets_mapping=True,
                    add_special_tokens=True,
                )

                do_lower_case = tokenizer_r.do_lower_case if hasattr(tokenizer_r, "do_lower_case") else False
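                # The expected (offset, token) pairs depend on whether the checkpoint lower-cases (and strips accents from) its input.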
                expected_results = (
                    [
                        ((0, 0), tokenizer_r.cls_token),
                        ((0, 1), "A"),
                        ((1, 2), ","),
                        ((3, 5), "na"),
                        ((5, 6), "##ï"),
                        ((6, 8), "##ve"),
                        ((9, 15), tokenizer_r.mask_token),
                        ((16, 21), "Allen"),
                        ((21, 23), "##NL"),
                        ((23, 24), "##P"),
                        ((25, 33), "sentence"),
                        ((33, 34), "."),
                        ((0, 0), tokenizer_r.sep_token),
                    ]
                    if not do_lower_case
                    else [
                        ((0, 0), tokenizer_r.cls_token),
                        ((0, 1), "a"),
                        ((1, 2), ","),
                        ((3, 8), "naive"),
                        ((9, 15), tokenizer_r.mask_token),
                        ((16, 21), "allen"),
                        ((21, 23), "##nl"),
                        ((23, 24), "##p"),
                        ((25, 33), "sentence"),
                        ((33, 34), "."),
                        ((0, 0), tokenizer_r.sep_token),
                    ]
                )

                self.assertEqual(
                    [e[1] for e in expected_results], tokenizer_r.convert_ids_to_tokens(tokens["input_ids"])
                )
                self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"])

    def test_change_tokenize_chinese_chars(self):
        list_of_commun_chinese_char = ["的", "人", "有"]
        text_with_chinese_char = "".join(list_of_commun_chinese_char)
        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
                kwargs["tokenize_chinese_chars"] = True
                tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
                tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)

                ids_without_spe_char_p = tokenizer_p.encode(text_with_chinese_char, add_special_tokens=False)
                ids_without_spe_char_r = tokenizer_r.encode(text_with_chinese_char, add_special_tokens=False)

                tokens_without_spe_char_r = tokenizer_r.convert_ids_to_tokens(ids_without_spe_char_r)
                tokens_without_spe_char_p = tokenizer_p.convert_ids_to_tokens(ids_without_spe_char_p)

                # it is expected that each Chinese character is not preceded by "##"
                self.assertListEqual(tokens_without_spe_char_p, list_of_commun_chinese_char)
                self.assertListEqual(tokens_without_spe_char_r, list_of_commun_chinese_char)

                kwargs["tokenize_chinese_chars"] = False
                tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
                tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)

                ids_without_spe_char_r = tokenizer_r.encode(text_with_chinese_char, add_special_tokens=False)
                ids_without_spe_char_p = tokenizer_p.encode(text_with_chinese_char, add_special_tokens=False)

                tokens_without_spe_char_r = tokenizer_r.convert_ids_to_tokens(ids_without_spe_char_r)
                tokens_without_spe_char_p = tokenizer_p.convert_ids_to_tokens(ids_without_spe_char_p)

                # it is expected that only the first Chinese character is not preceded by "##".
                expected_tokens = [
                    f"##{token}" if idx != 0 else token for idx, token in enumerate(list_of_commun_chinese_char)
                ]
                self.assertListEqual(tokens_without_spe_char_p, expected_tokens)
                self.assertListEqual(tokens_without_spe_char_r, expected_tokens)