# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import sys
from io import open
import tempfile
import shutil
import unittest

if sys.version_info[0] == 2:
    import cPickle as pickle

    class TemporaryDirectory(object):
        """Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
        def __enter__(self):
            self.name = tempfile.mkdtemp()
            return self.name
        def __exit__(self, exc_type, exc_value, traceback):
            shutil.rmtree(self.name)
else:
    import pickle
    TemporaryDirectory = tempfile.TemporaryDirectory
    unicode = str


class CommonTestCases:

    class CommonTokenizerTester(unittest.TestCase):

        tokenizer_class = None

        def setUp(self):
            self.tmpdirname = tempfile.mkdtemp()

        def tearDown(self):
            shutil.rmtree(self.tmpdirname)

        def get_tokenizer(self, **kwargs):
            raise NotImplementedError

        def get_input_output_texts(self):
            raise NotImplementedError

        def test_tokenizers_common_properties(self):
            tokenizer = self.get_tokenizer()
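            # Every tokenizer must expose the standard special token attributes
            # (and their *_id counterparts) as well as the added-token bookkeeping.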
            attributes_list = ["bos_token", "eos_token", "unk_token", "sep_token",
                                "pad_token", "cls_token", "mask_token"]
            for attr in attributes_list:
                self.assertTrue(hasattr(tokenizer, attr))
                self.assertTrue(hasattr(tokenizer, attr + "_id"))

            self.assertTrue(hasattr(tokenizer, "additional_special_tokens"))
            self.assertTrue(hasattr(tokenizer, 'additional_special_tokens_ids'))

            attributes_list = ["max_len", "init_inputs", "init_kwargs", "added_tokens_encoder",
                                "added_tokens_decoder"]
            for attr in attributes_list:
                self.assertTrue(hasattr(tokenizer, attr))

        def test_save_and_load_tokenizer(self):
            # safety check on max_len default value so we are sure the test works
            tokenizer = self.get_tokenizer()
            self.assertNotEqual(tokenizer.max_len, 42)

            # Now let's start the test
            tokenizer = self.get_tokenizer(max_len=42)

            before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
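            # The save_pretrained -> from_pretrained round trip below should preserve
            # both the encoding behaviour and the custom max_len.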

            with TemporaryDirectory() as tmpdirname:
                tokenizer.save_pretrained(tmpdirname)
                tokenizer = self.tokenizer_class.from_pretrained(tmpdirname)

                after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
                self.assertListEqual(before_tokens, after_tokens)

                self.assertEqual(tokenizer.max_len, 42)
                tokenizer = self.tokenizer_class.from_pretrained(tmpdirname, max_len=43)
                self.assertEqual(tokenizer.max_len, 43)

        def test_pickle_tokenizer(self):
            tokenizer = self.get_tokenizer()
            self.assertIsNotNone(tokenizer)
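            # Pickling then unpickling the tokenizer must not change how it tokenizes.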

            text = u"Munich and Berlin are nice cities"
            subwords = tokenizer.tokenize(text)

            with TemporaryDirectory() as tmpdirname:
                filename = os.path.join(tmpdirname, u"tokenizer.bin")
                pickle.dump(tokenizer, open(filename, "wb"))

                tokenizer_new = pickle.load(open(filename, "rb"))

            subwords_loaded = tokenizer_new.tokenize(text)

            self.assertListEqual(subwords, subwords_loaded)

        def test_add_tokens_tokenizer(self):
            tokenizer = self.get_tokenizer()

            vocab_size = tokenizer.vocab_size
            all_size = len(tokenizer)

            self.assertNotEqual(vocab_size, 0)
            self.assertEqual(vocab_size, all_size)
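            # add_tokens should grow len(tokenizer) while leaving vocab_size untouched.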

            new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd"]
            added_toks = tokenizer.add_tokens(new_toks)
            vocab_size_2 = tokenizer.vocab_size
            all_size_2 = len(tokenizer)

            self.assertNotEqual(vocab_size_2, 0)
            self.assertEqual(vocab_size, vocab_size_2)
            self.assertEqual(added_toks, len(new_toks))
            self.assertEqual(all_size_2, all_size + len(new_toks))

            tokens = tokenizer.encode("aaaaa bbbbbb low cccccccccdddddddd l")
            out_string = tokenizer.decode(tokens)

            self.assertGreaterEqual(len(tokens), 4)
            self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
            self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
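            # Special tokens added via add_special_tokens follow the same bookkeeping
            # and must be encoded to their newly assigned ids.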

            new_toks_2 = {'eos_token': ">>>>|||<||<<|<<",
                          'pad_token': "<<<<<|||>|>>>>|>"}
            added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
            vocab_size_3 = tokenizer.vocab_size
            all_size_3 = len(tokenizer)

            self.assertNotEqual(vocab_size_3, 0)
            self.assertEqual(vocab_size, vocab_size_3)
            self.assertEqual(added_toks_2, len(new_toks_2))
            self.assertEqual(all_size_3, all_size_2 + len(new_toks_2))

            tokens = tokenizer.encode(">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l")
            out_string = tokenizer.decode(tokens)

            self.assertGreaterEqual(len(tokens), 6)
            self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
            self.assertGreater(tokens[0], tokens[1])
            self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
            self.assertGreater(tokens[-2], tokens[-3])
            self.assertEqual(tokens[0], tokenizer.eos_token_id)
            self.assertEqual(tokens[-2], tokenizer.pad_token_id)

        def test_required_methods_tokenizer(self):
            tokenizer = self.get_tokenizer()
            input_text, output_text = self.get_input_output_texts()
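            # tokenize + convert_tokens_to_ids must agree with encode, and decode
            # must map the ids back to the expected output text.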

            tokens = tokenizer.tokenize(input_text)
            ids = tokenizer.convert_tokens_to_ids(tokens)
            ids_2 = tokenizer.encode(input_text)
            self.assertListEqual(ids, ids_2)

            tokens_2 = tokenizer.convert_ids_to_tokens(ids)
            text_2 = tokenizer.decode(ids)

            self.assertEqual(text_2, output_text)

            self.assertNotEqual(len(tokens_2), 0)
            self.assertIsInstance(text_2, (str, unicode))


        def test_pretrained_model_lists(self):
            weights_list = list(self.tokenizer_class.max_model_input_sizes.keys())
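            # Every checkpoint listed in max_model_input_sizes must also appear,
            # in the same order, in each pretrained_vocab_files_map entry.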
            weights_lists_2 = []
            for file_id, map_list in self.tokenizer_class.pretrained_vocab_files_map.items():
                weights_lists_2.append(list(map_list.keys()))

            for weights_list_2 in weights_lists_2:
                self.assertListEqual(weights_list, weights_list_2)

        def test_mask_output(self):
            if sys.version_info <= (3, 0):
                return

            tokenizer = self.get_tokenizer()

            if tokenizer.add_special_tokens_sequence_pair.__qualname__.split('.')[0] != "PreTrainedTokenizer":
                seq_0 = "Test this method."
                seq_1 = "With these inputs."
                information = tokenizer.encode_plus(seq_0, seq_1, add_special_tokens=True)
                sequences, mask = information["input_ids"], information["token_type_ids"]
                assert len(sequences) == len(mask)

        def test_number_of_added_tokens(self):
            tokenizer = self.get_tokenizer()

            seq_0 = "Test this method."
            seq_1 = "With these inputs."

            sequences = tokenizer.encode(seq_0, seq_1)
            attached_sequences = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
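            # num_added_tokens(pair=True) must match the number of special tokens
            # that encode(..., add_special_tokens=True) actually inserts.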

            # Method is implemented (e.g. not GPT-2)
            if len(attached_sequences) != 2:
                assert tokenizer.num_added_tokens(pair=True) == len(attached_sequences) - len(sequences)

        def test_maximum_encoding_length_single_input(self):
            tokenizer = self.get_tokenizer()

            seq_0 = "This is a sentence to be encoded."
            stride = 2

            sequence = tokenizer.encode(seq_0)
            num_added_tokens = tokenizer.num_added_tokens()
            total_length = len(sequence) + num_added_tokens
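            # Truncating two tokens below the full length must push exactly
            # 2 + stride tokens into the overflow, taken from the end of the sequence.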
            information = tokenizer.encode_plus(seq_0, max_length=total_length - 2, add_special_tokens=True, stride=stride)

            truncated_sequence = information["input_ids"]
            overflowing_tokens = information["overflowing_tokens"]

            assert len(overflowing_tokens) == 2 + stride
            assert overflowing_tokens == sequence[-(2 + stride):]
            assert len(truncated_sequence) == total_length - 2
            assert truncated_sequence == tokenizer.add_special_tokens_single_sequence(sequence[:-2])

        def test_maximum_encoding_length_pair_input(self):
            tokenizer = self.get_tokenizer()

            seq_0 = "This is a sentence to be encoded."
            seq_1 = "This is another sentence to be encoded."
            stride = 2

            sequence_0_no_special_tokens = tokenizer.encode(seq_0)
            sequence_1_no_special_tokens = tokenizer.encode(seq_1)
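            # With truncate_first_sequence=False the overflow must come from the second
            # sequence; with truncate_first_sequence=True it must come from the first.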

            sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
            truncated_second_sequence = tokenizer.add_special_tokens_sequence_pair(
                tokenizer.encode(seq_0),
                tokenizer.encode(seq_1)[:-2]
            )

            information = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True,
                                                stride=stride, truncate_first_sequence=False)
            information_first_truncated = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2,
                                                                add_special_tokens=True, stride=stride,
                                                                truncate_first_sequence=True)

            truncated_sequence = information["input_ids"]
            overflowing_tokens = information["overflowing_tokens"]
            overflowing_tokens_first_truncated = information_first_truncated["overflowing_tokens"]

            assert len(overflowing_tokens) == 2 + stride
            assert overflowing_tokens == sequence_1_no_special_tokens[-(2 + stride):]
            assert overflowing_tokens_first_truncated == sequence_0_no_special_tokens[-(2 + stride):]
            assert len(truncated_sequence) == len(sequence) - 2
            assert truncated_sequence == truncated_second_sequence

        def test_encode_input_type(self):
            tokenizer = self.get_tokenizer()

            sequence = "Let's encode this sequence"

            tokens = tokenizer.tokenize(sequence)
            input_ids = tokenizer.convert_tokens_to_ids(tokens)
            formatted_input = tokenizer.encode(sequence, add_special_tokens=True)

            assert tokenizer.encode(tokens, add_special_tokens=True) == formatted_input
            assert tokenizer.encode(input_ids, add_special_tokens=True) == formatted_input

        def test_special_tokens_mask(self):
            tokenizer = self.get_tokenizer()

            sequence_0 = "Encode this."
            sequence_1 = "This one too please."

            # Testing single inputs
            encoded_sequence = tokenizer.encode(sequence_0)
            encoded_sequence_dict = tokenizer.encode_plus(sequence_0, add_special_tokens=True)
            encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
            special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
            assert len(special_tokens_mask) == len(encoded_sequence_w_special)

            filtered_sequence = [(x if special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)]
            filtered_sequence = [x for x in filtered_sequence if x is not None]
            assert encoded_sequence == filtered_sequence

            # Testing inputs pairs
            encoded_sequence = tokenizer.encode(sequence_0) + tokenizer.encode(sequence_1)
            encoded_sequence_dict = tokenizer.encode_plus(sequence_0, sequence_1, add_special_tokens=True)
            encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
            special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
            assert len(special_tokens_mask) == len(encoded_sequence_w_special)

            filtered_sequence = [(x if special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)]
            filtered_sequence = [x for x in filtered_sequence if x is not None]
            assert encoded_sequence == filtered_sequence

            # Testing with already existing special tokens
            if tokenizer.cls_token_id == tokenizer.unk_token_id and tokenizer.sep_token_id == tokenizer.unk_token_id:
                tokenizer.add_special_tokens({'cls_token': '</s>', 'sep_token': '<s>'})
            encoded_sequence_dict = tokenizer.encode_plus(sequence_0, add_special_tokens=True)
            encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
            special_tokens_mask_orig = encoded_sequence_dict["special_tokens_mask"]
            special_tokens_mask = tokenizer.get_special_tokens_mask(encoded_sequence_w_special, special_tokens_present=True)
            assert len(special_tokens_mask) == len(encoded_sequence_w_special)
            assert special_tokens_mask_orig == special_tokens_mask
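
# ---------------------------------------------------------------------------
# Minimal sketch of a concrete test case built on the mixin above, kept as a
# comment because the tokenizer class, vocabulary fixture and expected strings
# are placeholders for illustration, not part of this module:
#
# class MyTokenizationTest(CommonTestCases.CommonTokenizerTester):
#
#     tokenizer_class = MyTokenizer  # a PreTrainedTokenizer subclass
#
#     def setUp(self):
#         super(MyTokenizationTest, self).setUp()
#         # write a small vocabulary/merges fixture into self.tmpdirname here
#
#     def get_tokenizer(self, **kwargs):
#         return MyTokenizer.from_pretrained(self.tmpdirname, **kwargs)
#
#     def get_input_output_texts(self):
#         return (u"He is very happy, UNwant\u00E9d,running",
#                 u"he is very happy , unwant\u00e9d , running")
# ---------------------------------------------------------------------------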