test_tokenization_common.py 34.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Aymeric Augustin's avatar
Aymeric Augustin committed
15

16

thomwolf's avatar
thomwolf committed
17
import os
18
import pickle
Aymeric Augustin's avatar
Aymeric Augustin committed
19
import shutil
20
import tempfile
21
from collections import OrderedDict
22
from typing import TYPE_CHECKING, Dict, Tuple, Union
Aymeric Augustin's avatar
Aymeric Augustin committed
23

24
25
from tests.utils import require_tf, require_torch

26

27
28
29
30
31
32
33
34
35
36
if TYPE_CHECKING:
    from transformers import (
        PretrainedConfig,
        PreTrainedTokenizer,
        PreTrainedTokenizerFast,
        PreTrainedModel,
        TFPreTrainedModel,
    )


37
def merge_model_tokenizer_mappings(
    model_mapping: Dict["PretrainedConfig", Union["PreTrainedModel", "TFPreTrainedModel"]],
    tokenizer_mapping: Dict["PretrainedConfig", Tuple["PreTrainedTokenizer", "PreTrainedTokenizerFast"]],
) -> Dict[
    Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
    Tuple["PretrainedConfig", Union["PreTrainedModel", "TFPreTrainedModel"]],
]:
    """Re-key config-indexed model/tokenizer mappings by tokenizer.

    For every configuration of ``model_mapping`` (in its iteration order), the first element of
    ``tokenizer_mapping[configuration]`` (the slow tokenizer) and, unless it is ``None``, the second
    element (the fast tokenizer) are each mapped to ``(configuration, model)``.

    Args:
        model_mapping: configuration -> model.
        tokenizer_mapping: configuration -> ``(tokenizer, tokenizer_fast_or_None)``.

    Returns:
        An ``OrderedDict`` tokenizer -> ``(configuration, model)``. When the same tokenizer appears
        under several configurations, the value from the last configuration visited wins.

    Raises:
        KeyError: if a configuration of ``model_mapping`` is absent from ``tokenizer_mapping``.
    """
    model_tokenizer_mapping = OrderedDict()

    for configuration, model in model_mapping.items():
        tokenizers = tokenizer_mapping[configuration]
        tokenizer = tokenizers[0]
        tokenizer_fast = tokenizers[1]

        model_tokenizer_mapping[tokenizer] = (configuration, model)
        # Configurations without a fast tokenizer store None in the second slot.
        if tokenizer_fast is not None:
            model_tokenizer_mapping[tokenizer_fast] = (configuration, model)

    return model_tokenizer_mapping


class TokenizerTesterMixin:
60

61
    # Tokenizer class under test. It is None here and must be set by every concrete test case:
    # the tests call ``self.tokenizer_class.from_pretrained(...)`` and read its
    # ``max_model_input_sizes`` and ``pretrained_vocab_files_map`` attributes.
    tokenizer_class = None
    # NOTE(review): not read anywhere in the visible part of this mixin; presumably switches on
    # tests that go through ``get_rust_tokenizer`` further down the file — confirm.
    test_rust_tokenizer = False

    def setUp(self):
        self.tmpdirname = tempfile.mkdtemp()
66

67
68
    def tearDown(self):
        shutil.rmtree(self.tmpdirname)
69

70
71
    def get_tokenizer(self, **kwargs):
        raise NotImplementedError
72

Anthony MOI's avatar
Anthony MOI committed
73
74
    def get_rust_tokenizer(self, **kwargs):
        raise NotImplementedError
75

76
77
    def get_input_output_texts(self):
        raise NotImplementedError
thomwolf's avatar
thomwolf committed
78

79
80
81
82
83
84
    @staticmethod
    def convert_batch_encode_plus_format_to_encode_plus(batch_encode_plus_sequences):
        # Switch from batch_encode_plus format:   {'input_ids': [[...], [...]], ...}
        # to the concatenated encode_plus format: [{'input_ids': [...], ...}, {'input_ids': [...], ...}]
        return [
            {value: batch_encode_plus_sequences[value][i] for value in batch_encode_plus_sequences.keys()}
Lysandre Debut's avatar
Lysandre Debut committed
85
            for i in range(len(batch_encode_plus_sequences["input_ids"]))
86
87
        ]

88
89
90
91
92
93
94
95
96
97
98
99
100
101
    def test_tokenizers_common_properties(self):
        """Check that the tokenizer exposes the attributes shared by all tokenizers."""
        tokenizer = self.get_tokenizer()

        # Every special token must be available as a token string and as ``<name>_id``.
        special_token_attributes = [
            "bos_token",
            "eos_token",
            "unk_token",
            "sep_token",
            "pad_token",
            "cls_token",
            "mask_token",
        ]
        for attr in special_token_attributes:
            self.assertTrue(hasattr(tokenizer, attr), "missing attribute " + attr)
            self.assertTrue(hasattr(tokenizer, attr + "_id"), "missing attribute " + attr + "_id")

        self.assertTrue(hasattr(tokenizer, "additional_special_tokens"))
        self.assertTrue(hasattr(tokenizer, "additional_special_tokens_ids"))

        other_attributes = ["max_len", "init_inputs", "init_kwargs", "added_tokens_encoder", "added_tokens_decoder"]
        for attr in other_attributes:
            self.assertTrue(hasattr(tokenizer, attr), "missing attribute " + attr)

    def test_save_and_load_tokenizer(self):
        """Round-trip a tokenizer through ``save_pretrained`` / ``from_pretrained``.

        The encoding must be unchanged, the saved ``max_len`` must be restored, and a ``max_len``
        passed to ``from_pretrained`` must override the saved value.
        """
        sample_text = "He is very happy, UNwant\u00E9d,running"

        # safety check on max_len default value so we are sure the test works
        tokenizer = self.get_tokenizer()
        self.assertNotEqual(tokenizer.max_len, 42)

        # Now let's start the test
        tokenizer = self.get_tokenizer(max_len=42)
        before_tokens = tokenizer.encode(sample_text, add_special_tokens=False)

        tokenizer.save_pretrained(self.tmpdirname)
        tokenizer = self.tokenizer_class.from_pretrained(self.tmpdirname)

        after_tokens = tokenizer.encode(sample_text, add_special_tokens=False)
        self.assertListEqual(before_tokens, after_tokens)

        self.assertEqual(tokenizer.max_len, 42)
        tokenizer = self.tokenizer_class.from_pretrained(self.tmpdirname, max_len=43)
        self.assertEqual(tokenizer.max_len, 43)

    def test_pickle_tokenizer(self):
        """A tokenizer pickled to disk and loaded back must tokenize text identically."""
        tokenizer = self.get_tokenizer()
        self.assertIsNotNone(tokenizer)

        text = "Munich and Berlin are nice cities"
        subwords = tokenizer.tokenize(text)

        filename = os.path.join(self.tmpdirname, "tokenizer.bin")
        with open(filename, "wb") as handle:
            pickle.dump(tokenizer, handle)

        # Unpickling is safe here: the file was written by this test just above.
        with open(filename, "rb") as handle:
            tokenizer_new = pickle.load(handle)

        subwords_loaded = tokenizer_new.tokenize(text)

        self.assertListEqual(subwords, subwords_loaded)

    def test_added_tokens_do_lower_case(self):
        """Added tokens must follow ``do_lower_case`` while special tokens are never lowercased."""
        tokenizer = self.get_tokenizer(do_lower_case=True)

        special_token = tokenizer.all_special_tokens[0]

        text = special_token + " aaaaa bbbbbb low cccccccccdddddddd l " + special_token
        text2 = special_token + " AAAAA BBBBBB low CCCCCCCCCDDDDDDDD l " + special_token

        toks0 = tokenizer.tokenize(text)  # toks before adding new_toks

        # With do_lower_case=True only 2 of these 4 tokens are expected to be new: the upper-case
        # variants coincide with the lower-case ones once lowercased.
        new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd", "AAAAA BBBBBB", "CCCCCCCCCDDDDDDDD"]
        added = tokenizer.add_tokens(new_toks)
        self.assertEqual(added, 2)

        toks = tokenizer.tokenize(text)
        toks2 = tokenizer.tokenize(text2)

        self.assertEqual(len(toks), len(toks2))
        self.assertNotEqual(len(toks), len(toks0))  # toks0 should be longer
        self.assertListEqual(toks, toks2)

        # Check that none of the special tokens are lowercased
        sequence_with_special_tokens = "A " + " yEs ".join(tokenizer.all_special_tokens) + " B"
        tokenized_sequence = tokenizer.tokenize(sequence_with_special_tokens)

        for token in tokenizer.all_special_tokens:
            self.assertIn(token, tokenized_sequence)

        # Without lowercasing all four added tokens are distinct.
        tokenizer = self.get_tokenizer(do_lower_case=False)

        added = tokenizer.add_tokens(new_toks)
        self.assertEqual(added, 4)

        toks = tokenizer.tokenize(text)
        toks2 = tokenizer.tokenize(text2)

        self.assertEqual(len(toks), len(toks2))  # Length should still be the same
        self.assertNotEqual(len(toks), len(toks0))
        self.assertNotEqual(toks[1], toks2[1])  # But at least the first non-special tokens should differ

    def test_add_tokens_tokenizer(self):
        """Adding regular or special tokens grows ``len(tokenizer)`` but leaves ``vocab_size`` unchanged."""
        tokenizer = self.get_tokenizer()

        vocab_size = tokenizer.vocab_size
        all_size = len(tokenizer)

        self.assertNotEqual(vocab_size, 0)
        self.assertEqual(vocab_size, all_size)

        # Regular added tokens.
        new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd"]
        added_toks = tokenizer.add_tokens(new_toks)
        vocab_size_2 = tokenizer.vocab_size
        all_size_2 = len(tokenizer)

        self.assertNotEqual(vocab_size_2, 0)
        self.assertEqual(vocab_size, vocab_size_2)
        self.assertEqual(added_toks, len(new_toks))
        self.assertEqual(all_size_2, all_size + len(new_toks))

        tokens = tokenizer.encode("aaaaa bbbbbb low cccccccccdddddddd l", add_special_tokens=False)

        # Added tokens must be encoded with ids beyond the base vocabulary.
        self.assertGreaterEqual(len(tokens), 4)
        self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
        self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)

        # Added special tokens.
        new_toks_2 = {"eos_token": ">>>>|||<||<<|<<", "pad_token": "<<<<<|||>|>>>>|>"}
        added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
        vocab_size_3 = tokenizer.vocab_size
        all_size_3 = len(tokenizer)

        self.assertNotEqual(vocab_size_3, 0)
        self.assertEqual(vocab_size, vocab_size_3)
        self.assertEqual(added_toks_2, len(new_toks_2))
        self.assertEqual(all_size_3, all_size_2 + len(new_toks_2))

        tokens = tokenizer.encode(
            ">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l", add_special_tokens=False
        )

        # The new special tokens must map to ids past the base vocabulary, above their neighbours,
        # and be reachable through the eos/pad token id attributes.
        self.assertGreaterEqual(len(tokens), 6)
        self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
        self.assertGreater(tokens[0], tokens[1])
        self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
        self.assertGreater(tokens[-2], tokens[-3])
        self.assertEqual(tokens[0], tokenizer.eos_token_id)
        self.assertEqual(tokens[-2], tokenizer.pad_token_id)

    def test_add_special_tokens(self):
        """A newly added special token encodes to one id, stays in place in context, and is skipped on decode."""
        tokenizer = self.get_tokenizer()
        input_text, output_text = self.get_input_output_texts()

        special_token = "[SPECIAL TOKEN]"

        tokenizer.add_special_tokens({"cls_token": special_token})
        special_token_id = tokenizer.encode(special_token, add_special_tokens=False)
        self.assertEqual(len(special_token_id), 1)

        text = " ".join([input_text, special_token, output_text])
        encoded = tokenizer.encode(text, add_special_tokens=False)

        input_encoded = tokenizer.encode(input_text, add_special_tokens=False)
        # The leading space mirrors the " " separator used when building `text`.
        output_encoded = tokenizer.encode(" " + output_text, add_special_tokens=False)
        self.assertEqual(encoded, input_encoded + special_token_id + output_encoded)

        decoded = tokenizer.decode(encoded, skip_special_tokens=True)
        self.assertNotIn(special_token, decoded)

    def test_required_methods_tokenizer(self):
        """tokenize + convert_tokens_to_ids must agree with encode, and decode must give the expected text."""
        tokenizer = self.get_tokenizer()
        input_text, output_text = self.get_input_output_texts()

        tokens = tokenizer.tokenize(input_text)
        ids = tokenizer.convert_tokens_to_ids(tokens)
        ids_2 = tokenizer.encode(input_text, add_special_tokens=False)
        self.assertListEqual(ids, ids_2)

        tokens_2 = tokenizer.convert_ids_to_tokens(ids)
        text_2 = tokenizer.decode(ids)

        self.assertEqual(text_2, output_text)

        self.assertNotEqual(len(tokens_2), 0)
        self.assertIsInstance(text_2, str)

    def test_encode_decode_with_spaces(self):
        """Added tokens, including one that contains a space, must survive an encode/decode round trip."""
        tokenizer = self.get_tokenizer()

        new_toks = ["[ABC]", "[DEF]", "GHI IHG"]
        tokenizer.add_tokens(new_toks)
        input_text = "[ABC] [DEF] [ABC] GHI IHG [DEF]"
        encoded = tokenizer.encode(input_text, add_special_tokens=False)
        decoded = tokenizer.decode(encoded)
        self.assertEqual(decoded, input_text)

    def test_pretrained_model_lists(self):
        """Every pretrained vocab-file map must list exactly the checkpoints of ``max_model_input_sizes``, in order."""
        weights_list = list(self.tokenizer_class.max_model_input_sizes.keys())
        for map_list in self.tokenizer_class.pretrained_vocab_files_map.values():
            self.assertListEqual(weights_list, list(map_list.keys()))

    def test_mask_output(self):
        """For a sequence pair, ``token_type_ids`` must have one entry per input id."""
        tokenizer = self.get_tokenizer()

        # Only checked when build_inputs_with_special_tokens is defined on a class other than
        # PreTrainedTokenizer (its __qualname__ starts with the defining class) and the tokenizer
        # lists token_type_ids among its model inputs.
        if (
            tokenizer.build_inputs_with_special_tokens.__qualname__.split(".")[0] != "PreTrainedTokenizer"
            and "token_type_ids" in tokenizer.model_input_names
        ):
            seq_0 = "Test this method."
            seq_1 = "With these inputs."
            information = tokenizer.encode_plus(seq_0, seq_1, add_special_tokens=True)
            sequences, mask = information["input_ids"], information["token_type_ids"]
            self.assertEqual(len(sequences), len(mask))

    def test_number_of_added_tokens(self):
        """``num_special_tokens_to_add(pair=True)`` must equal the growth caused by add_special_tokens=True."""
        tokenizer = self.get_tokenizer()

        seq_0 = "Test this method."
        seq_1 = "With these inputs."

        sequences = tokenizer.encode(seq_0, seq_1, add_special_tokens=False)
        attached_sequences = tokenizer.encode(seq_0, seq_1, add_special_tokens=True, add_prefix_space=False)

        # Method is implemented (e.g. not GPT-2)
        # NOTE(review): the guard compares the encoded pair's length with the literal 2, which does not
        # obviously test whether special tokens are implemented; kept as-is — confirm the intent.
        if len(attached_sequences) != 2:
            self.assertEqual(tokenizer.num_special_tokens_to_add(pair=True), len(attached_sequences) - len(sequences))

    def test_maximum_encoding_length_single_input(self):
        """Truncating one sequence by 2 tokens must return them, with ``stride`` tokens of context, as overflow."""
        tokenizer = self.get_tokenizer()

        seq_0 = "This is a sentence to be encoded."
        stride = 2

        sequence = tokenizer.encode(seq_0, add_special_tokens=False)
        num_added_tokens = tokenizer.num_special_tokens_to_add()
        total_length = len(sequence) + num_added_tokens
        # Request 2 tokens fewer than needed so exactly 2 tokens have to be truncated.
        information = tokenizer.encode_plus(
            seq_0,
            max_length=total_length - 2,
            add_special_tokens=True,
            stride=stride,
            return_overflowing_tokens=True,
            add_prefix_space=False,
        )

        truncated_sequence = information["input_ids"]
        overflowing_tokens = information["overflowing_tokens"]

        # Overflow = the 2 truncated tokens preceded by `stride` tokens of context.
        self.assertEqual(len(overflowing_tokens), 2 + stride)
        self.assertEqual(overflowing_tokens, sequence[-(2 + stride) :])
        self.assertEqual(len(truncated_sequence), total_length - 2)
        self.assertEqual(truncated_sequence, tokenizer.build_inputs_with_special_tokens(sequence[:-2]))

    def test_maximum_encoding_length_pair_input(self):
        """Pair truncation with ``only_second`` / ``only_first`` must overflow from the selected sequence."""
        tokenizer = self.get_tokenizer()

        seq_0 = "This is a sentence to be encoded."
        seq_1 = "This is another sentence to be encoded."
        stride = 2

        sequence_0_no_special_tokens = tokenizer.encode(seq_0, add_special_tokens=False)
        sequence_1_no_special_tokens = tokenizer.encode(seq_1, add_special_tokens=False)

        sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True, add_prefix_space=False)
        # Expected ids when the last 2 tokens of the second sequence are dropped.
        truncated_second_sequence = tokenizer.build_inputs_with_special_tokens(
            tokenizer.encode(seq_0, add_special_tokens=False), tokenizer.encode(seq_1, add_special_tokens=False)[:-2],
        )

        information = tokenizer.encode_plus(
            seq_0,
            seq_1,
            max_length=len(sequence) - 2,
            add_special_tokens=True,
            stride=stride,
            truncation_strategy="only_second",
            return_overflowing_tokens=True,
            add_prefix_space=False,
        )
        information_first_truncated = tokenizer.encode_plus(
            seq_0,
            seq_1,
            max_length=len(sequence) - 2,
            add_special_tokens=True,
            stride=stride,
            truncation_strategy="only_first",
            return_overflowing_tokens=True,
            add_prefix_space=False,
        )

        truncated_sequence = information["input_ids"]
        overflowing_tokens = information["overflowing_tokens"]
        overflowing_tokens_first_truncated = information_first_truncated["overflowing_tokens"]

        self.assertEqual(len(overflowing_tokens), 2 + stride)
        self.assertEqual(overflowing_tokens, sequence_1_no_special_tokens[-(2 + stride) :])
        self.assertEqual(overflowing_tokens_first_truncated, sequence_0_no_special_tokens[-(2 + stride) :])
        self.assertEqual(len(truncated_sequence), len(sequence) - 2)
        self.assertEqual(truncated_sequence, truncated_second_sequence)

    def test_encode_input_type(self):
        """encode must give the same ids for a raw string, its token list, and its id list."""
        tokenizer = self.get_tokenizer()

        sequence = "Let's encode this sequence"

        tokens = tokenizer.tokenize(sequence)
        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        formatted_input = tokenizer.encode(sequence, add_special_tokens=True, add_prefix_space=False)

        self.assertEqual(tokenizer.encode(tokens, add_special_tokens=True), formatted_input)
        self.assertEqual(tokenizer.encode(input_ids, add_special_tokens=True), formatted_input)

    def test_swap_special_token(self):
        """Replacing one word by the mask token must change only the id at that position."""
        tokenizer = self.get_tokenizer()

        mask = "<mask>"
        sequence = "Encode this sequence"
        masked_sequences = [
            "Encode <mask> sequence",
            "<mask> this sequence",
        ]

        # Add tokens so that masked token isn't split
        tokenizer.add_tokens(sequence.split())
        tokenizer.add_special_tokens({"mask_token": mask})
        mask_ind = tokenizer.convert_tokens_to_ids(mask)
        encoded = tokenizer.encode(sequence, add_special_tokens=False)

        for sequence_masked in masked_sequences:
            encoded_masked = tokenizer.encode(sequence_masked, add_special_tokens=False)
            mask_loc = encoded_masked.index(mask_ind)
            # Put the original id back where the mask was; both encodings must then be identical.
            encoded_masked[mask_loc] = encoded[mask_loc]

            self.assertEqual(encoded_masked, encoded)

    def test_special_tokens_mask(self):
        """``special_tokens_mask`` must flag exactly the ids inserted by ``add_special_tokens=True``."""
        tokenizer = self.get_tokenizer()

        sequence_0 = "Encode this."
        sequence_1 = "This one too please."

        def drop_special(ids, special_mask):
            # Keep only the ids whose mask entry is falsy (i.e. not a special token).
            return [token_id for token_id, is_special in zip(ids, special_mask) if not is_special]

        # Testing single inputs
        encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
        encoded_sequence_dict = tokenizer.encode_plus(
            sequence_0, add_special_tokens=True, return_special_tokens_mask=True, add_prefix_space=False
        )
        encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
        special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
        self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
        self.assertEqual(encoded_sequence, drop_special(encoded_sequence_w_special, special_tokens_mask))

        # Testing inputs pairs
        encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
        encoded_sequence += tokenizer.encode(sequence_1, add_special_tokens=False)
        encoded_sequence_dict = tokenizer.encode_plus(
            sequence_0, sequence_1, add_special_tokens=True, return_special_tokens_mask=True, add_prefix_space=False
        )
        encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
        special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
        self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
        self.assertEqual(encoded_sequence, drop_special(encoded_sequence_w_special, special_tokens_mask))

        # Testing with already existing special tokens
        # When the CLS or the SEP token resolves to the unknown-token id (i.e. it is missing from the
        # vocabulary), register dedicated ones — presumably so the mask recomputed below with
        # already_has_special_tokens=True cannot confuse them with genuine unknown tokens.
        if tokenizer.cls_token_id == tokenizer.unk_token_id or tokenizer.sep_token_id == tokenizer.unk_token_id:
            tokenizer.add_special_tokens({"cls_token": "</s>", "sep_token": "<s>"})
        encoded_sequence_dict = tokenizer.encode_plus(
            sequence_0, add_special_tokens=True, return_special_tokens_mask=True
        )
        encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
        special_tokens_mask_orig = encoded_sequence_dict["special_tokens_mask"]
        special_tokens_mask = tokenizer.get_special_tokens_mask(
            encoded_sequence_w_special, already_has_special_tokens=True
        )
        self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
        self.assertEqual(special_tokens_mask_orig, special_tokens_mask)

    def test_padding_to_max_length(self):
        """``pad_to_max_length`` must pad on ``padding_side`` up to ``max_length`` and do nothing without it."""
        tokenizer = self.get_tokenizer()

        sequence = "Sequence"
        padding_size = 10

        # check correct behaviour if no pad_token_id exists and add it eventually
        self._check_no_pad_token_padding(tokenizer, sequence)

        padding_idx = tokenizer.pad_token_id

        # RIGHT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True
        tokenizer.padding_side = "right"
        encoded_sequence = tokenizer.encode(sequence)
        sequence_length = len(encoded_sequence)
        padded_sequence = tokenizer.encode(sequence, max_length=sequence_length + padding_size, pad_to_max_length=True)
        padded_sequence_length = len(padded_sequence)
        self.assertEqual(sequence_length + padding_size, padded_sequence_length)
        self.assertEqual(encoded_sequence + [padding_idx] * padding_size, padded_sequence)

        # LEFT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True
        tokenizer.padding_side = "left"
        encoded_sequence = tokenizer.encode(sequence)
        sequence_length = len(encoded_sequence)
        padded_sequence = tokenizer.encode(sequence, max_length=sequence_length + padding_size, pad_to_max_length=True)
        padded_sequence_length = len(padded_sequence)
        self.assertEqual(sequence_length + padding_size, padded_sequence_length)
        self.assertEqual([padding_idx] * padding_size + encoded_sequence, padded_sequence)

        # RIGHT & LEFT PADDING - Check that nothing is done when a maximum length is not specified
        encoded_sequence = tokenizer.encode(sequence)
        sequence_length = len(encoded_sequence)

        tokenizer.padding_side = "right"
        padded_sequence_right = tokenizer.encode(sequence, pad_to_max_length=True)
        padded_sequence_right_length = len(padded_sequence_right)

        tokenizer.padding_side = "left"
        padded_sequence_left = tokenizer.encode(sequence, pad_to_max_length=True)
        padded_sequence_left_length = len(padded_sequence_left)

        self.assertEqual(sequence_length, padded_sequence_right_length)
        self.assertEqual(encoded_sequence, padded_sequence_right)
        self.assertEqual(sequence_length, padded_sequence_left_length)
        self.assertEqual(encoded_sequence, padded_sequence_left)

    def test_encode_plus_with_padding(self):
        """Padding in ``encode_plus`` must extend every returned field consistently on either side.

        ``input_ids`` are padded with ``pad_token_id`` and ``special_tokens_mask`` with 1. When the
        tokenizer lists them in ``model_input_names``, ``token_type_ids`` are padded with
        ``pad_token_type_id`` and ``attention_mask`` with 0.
        """
        tokenizer = self.get_tokenizer()

        sequence = "Sequence"

        # check correct behaviour if no pad_token_id exists and add it eventually
        self._check_no_pad_token_padding(tokenizer, sequence)

        padding_size = 10
        padding_idx = tokenizer.pad_token_id
        token_type_padding_idx = tokenizer.pad_token_type_id

        encoded_sequence = tokenizer.encode_plus(sequence, return_special_tokens_mask=True)
        input_ids = encoded_sequence["input_ids"]
        special_tokens_mask = encoded_sequence["special_tokens_mask"]
        sequence_length = len(input_ids)

        # Test right padding
        tokenizer.padding_side = "right"

        right_padded_sequence = tokenizer.encode_plus(
            sequence,
            max_length=sequence_length + padding_size,
            pad_to_max_length=True,
            return_special_tokens_mask=True,
        )
        right_padded_input_ids = right_padded_sequence["input_ids"]

        right_padded_special_tokens_mask = right_padded_sequence["special_tokens_mask"]
        right_padded_sequence_length = len(right_padded_input_ids)

        self.assertEqual(sequence_length + padding_size, right_padded_sequence_length)
        self.assertEqual(input_ids + [padding_idx] * padding_size, right_padded_input_ids)
        self.assertEqual(special_tokens_mask + [1] * padding_size, right_padded_special_tokens_mask)

        # Test left padding
        tokenizer.padding_side = "left"
        left_padded_sequence = tokenizer.encode_plus(
            sequence,
            max_length=sequence_length + padding_size,
            pad_to_max_length=True,
            return_special_tokens_mask=True,
        )
        left_padded_input_ids = left_padded_sequence["input_ids"]
        left_padded_special_tokens_mask = left_padded_sequence["special_tokens_mask"]
        left_padded_sequence_length = len(left_padded_input_ids)

        self.assertEqual(sequence_length + padding_size, left_padded_sequence_length)
        self.assertEqual([padding_idx] * padding_size + input_ids, left_padded_input_ids)
        self.assertEqual([1] * padding_size + special_tokens_mask, left_padded_special_tokens_mask)

        if "token_type_ids" in tokenizer.model_input_names:
            token_type_ids = encoded_sequence["token_type_ids"]
            left_padded_token_type_ids = left_padded_sequence["token_type_ids"]
            right_padded_token_type_ids = right_padded_sequence["token_type_ids"]

            self.assertEqual(token_type_ids + [token_type_padding_idx] * padding_size, right_padded_token_type_ids)
            self.assertEqual([token_type_padding_idx] * padding_size + token_type_ids, left_padded_token_type_ids)

        if "attention_mask" in tokenizer.model_input_names:
            attention_mask = encoded_sequence["attention_mask"]
            right_padded_attention_mask = right_padded_sequence["attention_mask"]
            left_padded_attention_mask = left_padded_sequence["attention_mask"]

            self.assertEqual(attention_mask + [0] * padding_size, right_padded_attention_mask)
            self.assertEqual([0] * padding_size + attention_mask, left_padded_attention_mask)

    def test_separate_tokenizers(self):
        """Building a second tokenizer must not alter the ``init_kwargs`` of the first one."""
        # This tests that tokenizers don't impact others. Unfortunately the case where it fails is when
        # we're loading an S3 configuration from a pre-trained identifier, and we have no way of testing those today.

        tokenizer = self.get_tokenizer(random_argument=True)
        self.assertIs(tokenizer.init_kwargs["random_argument"], True)
        new_tokenizer = self.get_tokenizer(random_argument=False)
        self.assertIs(tokenizer.init_kwargs["random_argument"], True)
        self.assertIs(new_tokenizer.init_kwargs["random_argument"], False)

    def test_get_vocab(self):
        """``get_vocab`` must cover the whole tokenizer, added tokens included, and agree with the id converters."""
        tokenizer = self.get_tokenizer()

        def check_vocab():
            # The vocab must be a dict of len(tokenizer) entries that round-trips through both converters.
            vocab = tokenizer.get_vocab()
            self.assertIsInstance(vocab, dict)
            self.assertEqual(len(vocab), len(tokenizer))

            for word, ind in vocab.items():
                self.assertEqual(tokenizer.convert_tokens_to_ids(word), ind)
                self.assertEqual(tokenizer.convert_ids_to_tokens(ind), word)

        check_vocab()

        # Tokens added after construction must show up as well.
        tokenizer.add_tokens(["asdfasdfasdfasdf"])
        check_vocab()

    def test_batch_encode_plus_batch_sequence_length(self):
        """batch_encode_plus must match per-sequence encode_plus, both unpadded and padded to the longest."""
        # Tests that all encoded values have the correct size
        tokenizer = self.get_tokenizer()
        sequences = [
            "Testing batch encode plus",
            "Testing batch encode plus with different sequence lengths",
            "Testing batch encode plus with different sequence lengths correctly pads",
        ]

        encoded_sequences = [tokenizer.encode_plus(sequence, pad_to_max_length=False) for sequence in sequences]
        encoded_sequences_batch = tokenizer.batch_encode_plus(sequences)
        self.assertListEqual(
            encoded_sequences, self.convert_batch_encode_plus_format_to_encode_plus(encoded_sequences_batch)
        )

        maximum_length = max(len(encoded_sequence["input_ids"]) for encoded_sequence in encoded_sequences)

        # check correct behaviour if no pad_token_id exists and add it eventually
        self._check_no_pad_token_padding(tokenizer, sequences)

        # Padding each sequence to the longest one must reproduce the padded batch.
        encoded_sequences_padded = [
            tokenizer.encode_plus(sequence, pad_to_max_length=True, max_length=maximum_length)
            for sequence in sequences
        ]

        encoded_sequences_batch_padded = tokenizer.batch_encode_plus(sequences, pad_to_max_length=True)
        self.assertListEqual(
            encoded_sequences_padded,
            self.convert_batch_encode_plus_format_to_encode_plus(encoded_sequences_batch_padded),
        )

    def test_batch_encode_plus_padding(self):
        """Padded batch_encode_plus output must equal padded per-sequence encode_plus output on both sides."""
        # Test that padded sequences are equivalent between batch_encode_plus and encode_plus
        sequences = [
            "Testing batch encode plus",
            "Testing batch encode plus with different sequence lengths",
            "Testing batch encode plus with different sequence lengths correctly pads",
        ]
        max_length = 100

        # Pass 1 keeps the fresh tokenizer's default padding side (expected to be right padding);
        # pass 2 forces left padding. Each pass uses a new tokenizer.
        for padding_side in (None, "left"):
            tokenizer = self.get_tokenizer()
            if padding_side is not None:
                tokenizer.padding_side = padding_side

            # check correct behaviour if no pad_token_id exists and add it eventually
            self._check_no_pad_token_padding(tokenizer, sequences)

            encoded_sequences = [
                tokenizer.encode_plus(sequence, pad_to_max_length=True, max_length=max_length)
                for sequence in sequences
            ]
            encoded_sequences_batch = tokenizer.batch_encode_plus(
                sequences, pad_to_max_length=True, max_length=max_length
            )
            self.assertListEqual(
                encoded_sequences, self.convert_batch_encode_plus_format_to_encode_plus(encoded_sequences_batch)
            )

    @require_torch
    @require_tf
    def test_batch_encode_plus_tensors(self):
        """Check batch_encode_plus tensor outputs (PyTorch and TensorFlow) against the list output.

        Without padding, sequences of different lengths cannot be stacked, so requesting tensors
        must raise ``ValueError``. If the tokenizer has no pad token, padded tensor requests are
        expected to raise ``ValueError`` too. Otherwise, for every returned key, the PyTorch and
        TensorFlow outputs must both equal the plain (list) padded output.
        """
        tokenizer = self.get_tokenizer()
        sequences = [
            "Testing batch encode plus",
            "Testing batch encode plus with different sequence lengths",
            "Testing batch encode plus with different sequence lengths correctly pads",
        ]

        # A Tensor cannot be built from sequences which are not the same size
        self.assertRaises(ValueError, tokenizer.batch_encode_plus, sequences, return_tensors="pt")
        self.assertRaises(ValueError, tokenizer.batch_encode_plus, sequences, return_tensors="tf")

        if tokenizer.pad_token_id is None:
            self.assertRaises(
                ValueError, tokenizer.batch_encode_plus, sequences, pad_to_max_length=True, return_tensors="pt"
            )
            self.assertRaises(
                ValueError, tokenizer.batch_encode_plus, sequences, pad_to_max_length=True, return_tensors="tf"
            )
        else:
            pytorch_tensor = tokenizer.batch_encode_plus(sequences, pad_to_max_length=True, return_tensors="pt")
            tensorflow_tensor = tokenizer.batch_encode_plus(sequences, pad_to_max_length=True, return_tensors="tf")
            encoded_sequences = tokenizer.batch_encode_plus(sequences, pad_to_max_length=True)

            for key in encoded_sequences.keys():
                pytorch_value = pytorch_tensor[key].tolist()
                tensorflow_value = tensorflow_tensor[key].numpy().tolist()
                encoded_value = encoded_sequences[key]

                # assertEqual compares only its first two arguments (a third positional argument
                # is the failure message), so check each framework against the list output.
                self.assertEqual(pytorch_value, encoded_value)
                self.assertEqual(tensorflow_value, encoded_value)
734
735
736
737
738
739
740
741
742
743
744
745

    def _check_no_pad_token_padding(self, tokenizer, sequences):
        """Assert that padding fails without a pad token, then register one on ``tokenizer``.

        Does nothing when ``tokenizer.pad_token_id`` is already set. Otherwise, asserts that padding
        raises ``ValueError``: ``batch_encode_plus`` is used when ``sequences`` is a list, and
        ``encode_plus`` for anything else. Afterwards it adds ``"<PAD>"`` as the pad token so that
        the caller's later padding checks can run. Note that this mutates ``tokenizer``.
        """
        if tokenizer.pad_token_id is not None:
            return

        encode = tokenizer.batch_encode_plus if isinstance(sequences, list) else tokenizer.encode_plus
        with self.assertRaises(ValueError):
            encode(sequences, pad_to_max_length=True)

        # add pad_token_id to pass subsequent tests
        tokenizer.add_special_tokens({"pad_token": "<PAD>"})
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825

    @require_torch
    def test_torch_encode_plus_sent_to_model(self):
        """Feed PyTorch outputs of encode_plus / batch_encode_plus into the matching model.

        The model class is looked up from the tokenizer class via ``MODEL_MAPPING`` and
        ``TOKENIZER_MAPPING``. The test returns early (passing without checks) when:
          - the tokenizer has no associated model;
          - the default config is encoder-decoder;
          - the default config has no ``pad_token_id``.
        When ``self.test_rust_tokenizer`` is set, the fast tokenizer's outputs are fed as well.
        """
        from transformers import MODEL_MAPPING, TOKENIZER_MAPPING

        MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(MODEL_MAPPING, TOKENIZER_MAPPING)

        tokenizer = self.get_tokenizer()

        if tokenizer.__class__ not in MODEL_TOKENIZER_MAPPING:
            return

        config_class, model_class = MODEL_TOKENIZER_MAPPING[tokenizer.__class__]
        config = config_class()

        if config.is_encoder_decoder or config.pad_token_id is None:
            return

        model = model_class(config)

        # Make sure the model contains at least the full vocabulary size in its embedding matrix.
        # Only checkable when the input embedding module exposes a ``weight`` tensor. Uses a
        # unittest assertion rather than a bare ``assert`` so the check survives ``python -O``.
        input_embeddings = model.get_input_embeddings()
        if hasattr(input_embeddings, "weight"):
            self.assertGreaterEqual(input_embeddings.weight.shape[0], len(tokenizer))

        # Build sequence
        first_ten_tokens = list(tokenizer.get_vocab().keys())[:10]
        sequence = " ".join(first_ten_tokens)
        encoded_sequence = tokenizer.encode_plus(sequence, return_tensors="pt")
        batch_encoded_sequence = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="pt")
        # This should not fail
        model(**encoded_sequence)
        model(**batch_encoded_sequence)

        if self.test_rust_tokenizer:
            fast_tokenizer = self.get_rust_tokenizer()
            encoded_sequence_fast = fast_tokenizer.encode_plus(sequence, return_tensors="pt")
            batch_encoded_sequence_fast = fast_tokenizer.batch_encode_plus([sequence, sequence], return_tensors="pt")
            # This should not fail
            model(**encoded_sequence_fast)
            model(**batch_encoded_sequence_fast)

    @require_tf
    def test_tf_encode_plus_sent_to_model(self):
        from transformers import TF_MODEL_MAPPING, TOKENIZER_MAPPING

        MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(TF_MODEL_MAPPING, TOKENIZER_MAPPING)

        tokenizer = self.get_tokenizer()

        if tokenizer.__class__ not in MODEL_TOKENIZER_MAPPING:
            return

        config_class, model_class = MODEL_TOKENIZER_MAPPING[tokenizer.__class__]
        config = config_class()

        if config.is_encoder_decoder or config.pad_token_id is None:
            return

        model = model_class(config)

        # Make sure the model contains at least the full vocabulary size in its embedding matrix
        assert model.config.vocab_size >= len(tokenizer)

        # Build sequence
        first_ten_tokens = list(tokenizer.get_vocab().keys())[:10]
        sequence = " ".join(first_ten_tokens)
        encoded_sequence = tokenizer.encode_plus(sequence, return_tensors="tf")
        batch_encoded_sequence = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="tf")

        # This should not fail
        model(encoded_sequence)
        model(batch_encoded_sequence)

        if self.test_rust_tokenizer:
            fast_tokenizer = self.get_rust_tokenizer()
            encoded_sequence_fast = fast_tokenizer.encode_plus(sequence, return_tensors="tf")
            batch_encoded_sequence_fast = fast_tokenizer.batch_encode_plus([sequence, sequence], return_tensors="tf")
            # This should not fail
            model(encoded_sequence_fast)
            model(batch_encoded_sequence_fast)