test_tokenization_fast.py 42.4 KB
Newer Older
1
import logging
2
import unittest
Funtowicz Morgan's avatar
Funtowicz Morgan committed
3
4
from collections import namedtuple
from itertools import takewhile
5
6
7
8
9
10
11
12
13
14
15
16

from transformers import (
    BertTokenizer,
    BertTokenizerFast,
    DistilBertTokenizer,
    GPT2Tokenizer,
    GPT2TokenizerFast,
    OpenAIGPTTokenizer,
    PreTrainedTokenizer,
    RobertaTokenizer,
    is_torch_available,
)
Sam Shleifer's avatar
Sam Shleifer committed
17
from transformers.testing_utils import get_tests_dir
18
19
20
21
22
from transformers.tokenization_distilbert import DistilBertTokenizerFast
from transformers.tokenization_openai import OpenAIGPTTokenizerFast
from transformers.tokenization_roberta import RobertaTokenizerFast


23
24
# Module-level logger, named after this module
logger = logging.getLogger(name=__name__)

Funtowicz Morgan's avatar
Funtowicz Morgan committed
25
# Substrings identifying pretrained checkpoints that are not English-only
NON_ENGLISH_TAGS = ["chinese", "dutch", "french", "finnish", "german", "multilingual"]
# One fast-vs-slow test case: display name, Rust (fast) class, Python (slow) class,
# key into pretrained_vocab_files_map, optional checkpoint filter, and extra from_pretrained kwargs
Tokenizer = namedtuple("Tokenizer", "name rust_cls python_cls vocab_key filter kwargs")
Funtowicz Morgan's avatar
Funtowicz Morgan committed
27

28

Funtowicz Morgan's avatar
Funtowicz Morgan committed
29
30
31
def filter_non_english(_: Tokenizer, pretrained_name: str):
    """Select only English checkpoints.

    Returns True (keep the checkpoint) when ``pretrained_name`` contains none of
    the substrings in ``NON_ENGLISH_TAGS``, and False otherwise. The first
    argument (the test case) is ignored.
    """
    # Generator instead of a list: any() can stop at the first matching tag
    return not any(lang in pretrained_name for lang in NON_ENGLISH_TAGS)
32
33


Funtowicz Morgan's avatar
Funtowicz Morgan committed
34
35
def filter_roberta_detectors(_: Tokenizer, pretrained_name: str):
    """Select checkpoints whose name does not contain "detector" (the test case argument is ignored)."""
    is_detector = "detector" in pretrained_name
    return not is_detector
36
37


Funtowicz Morgan's avatar
Funtowicz Morgan committed
38
class CommonFastTokenizerTest(unittest.TestCase):
39

Funtowicz Morgan's avatar
Funtowicz Morgan committed
40
41
42
    # Tokenizer cases to exercise; empty here, so the base class runs no cases.
    # NOTE(review): presumably overridden by subclasses with their own Tokenizer tuples.
    TOKENIZERS_CLASSES = frozenset()

    def setUp(self) -> None:
        """Load the shared sample-text fixture into ``self._data``.

        Every double newline in the fixture is collapsed into a single newline,
        and surrounding whitespace is stripped.
        """
        fixture_path = f"{get_tests_dir()}/fixtures/sample_text.txt"
        with open(fixture_path, encoding="utf-8") as fixture:
            raw_text = fixture.read()
        self._data = raw_text.replace("\n\n", "\n").strip()
45

Funtowicz Morgan's avatar
Funtowicz Morgan committed
46
47
48
49
50
51
52
53
54
    def test_all_tokenizers(self):
        """Compare every fast/slow tokenizer pair declared in ``TOKENIZERS_CLASSES``.

        For each case, every checkpoint listed in
        ``python_cls.pretrained_vocab_files_map[vocab_key]`` that passes the case's
        optional ``filter`` is loaded with both the Rust and the Python class (using
        the case's ``kwargs``). The pair is checked for alignment, and the Rust
        tokenizer also runs the fast-only checks. Each checkpoint runs in its own subTest.
        """
        for tok_case in self.TOKENIZERS_CLASSES:
            for pretrained_name in tok_case.python_cls.pretrained_vocab_files_map[tok_case.vocab_key]:

                # Tokenizer.filter makes it possible to choose which checkpoints to test based on all
                # the information available in Tokenizer (name, rust class, python class, vocab key name)
                if tok_case.filter is not None and not tok_case.filter(tok_case, pretrained_name):
                    continue

                # NOTE(review): kwargs is presumably stored as (key, value) pairs so the Tokenizer tuple
                # stays hashable inside a frozenset; dict() also accepts a plain mapping.
                kwargs = dict(tok_case.kwargs) if tok_case.kwargs else {}

                with self.subTest(f"{tok_case.name} ({pretrained_name})"):
                    tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, **kwargs)
                    tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name, **kwargs)

                    self.fast_align_python(tokenizer_r, tokenizer_p, tok_case, pretrained_name)
                    self.fast_only(tokenizer_r)

63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
    def test_pretokenized_tokenizers(self):
        """Compare fast and slow tokenizers on pre-split (word list) inputs.

        Every checkpoint of every case in ``TOKENIZERS_CLASSES`` that passes the case's
        optional ``filter`` is loaded with ``add_prefix_space=True`` for both
        implementations. The case's own ``kwargs`` are not applied here. Each
        checkpoint runs in its own subTest.
        """
        for tok_case in self.TOKENIZERS_CLASSES:
            for pretrained_name in tok_case.python_cls.pretrained_vocab_files_map[tok_case.vocab_key]:

                # Tokenizer.filter makes it possible to choose which checkpoints to test based on all
                # the information available in Tokenizer (name, rust class, python class, vocab key name)
                if tok_case.filter is not None and not tok_case.filter(tok_case, pretrained_name):
                    continue

                with self.subTest(f"{tok_case.name} ({pretrained_name})"):
                    # NOTE(review): add_prefix_space=True is presumably needed for word-list inputs on
                    # byte-level BPE tokenizers — confirm before changing
                    tokenizer_r = tok_case.rust_cls.from_pretrained(pretrained_name, add_prefix_space=True)
                    tokenizer_p = tok_case.python_cls.from_pretrained(pretrained_name, add_prefix_space=True)

                    self.assert_pretokenized_inputs(tokenizer_r, tokenizer_p)

78
    def fast_align_python(self, tokenizer_r, tokenizer_p, tok_case, pretrained_name):
        """Run the Rust-vs-Python alignment checks on one pair of tokenizers.

        ``tok_case`` and ``pretrained_name`` are accepted but not used by this implementation.
        """
        # is_fast must tell the two implementations apart
        self.assertFalse(tokenizer_p.is_fast)
        self.assertTrue(tokenizer_r.is_fast)

        # Every check takes (rust tokenizer, python tokenizer); they run in this order
        alignment_checks = (
            self.assert_tokenization_python_rust_equals,
            self.assert_num_special_tokens_to_add_equal,
            self.assert_max_length_equal,
            self.assert_special_tokens_map_equal,
            self.assert_embeded_special_tokens,
            self.assert_padding,
            self.assert_create_token_type_ids,
            self.assert_prepare_for_model,
        )
        for check in alignment_checks:
            check(tokenizer_r, tokenizer_p)
Funtowicz Morgan's avatar
Funtowicz Morgan committed
92
93
94
95
96
97
98
99
100
101
102

    def fast_only(self, tokenizer_r):
        """Run the checks that apply only to the fast (Rust) tokenizer.

        Note that ``assert_add_tokens`` adds tokens to ``tokenizer_r``; the later checks
        run on the modified tokenizer.
        """
        # Each encoding entry point must reject None with a ValueError
        entry_points = (
            tokenizer_r.tokenize,
            tokenizer_r.encode,
            tokenizer_r.encode_plus,
            tokenizer_r.batch_encode_plus,
        )
        for entry_point in entry_points:
            self.assertRaises(ValueError, entry_point, None)

        fast_checks = (
            self.assert_add_tokens,
            self.assert_offsets_mapping,
            self.assert_add_special_tokens,
            self.assert_alignement_methods,
            self.assert_batch_encode_dynamic_overflowing,
        )
        for check in fast_checks:
            check(tokenizer_r)
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182

    def assert_alignement_methods(self, tokenizer_r):
        """Check the word/token/char alignment helpers of single and batched encodings.

        A six-word text is encoded without special tokens, once alone and once as a
        batch of three copies. Each mapping method is probed at the first and the last
        position, with and without an explicit batch index. The first position must
        map to 0 (``.start`` for span results). The last position must map to the last
        index (``.end`` == last index + 1 for span results).
        """
        words = ["Wonderful", "no", "inspiration", "example", "with", "subtoken"]
        text = " ".join(words)
        batch_size = 3

        encoding = tokenizer_r.encode_plus(text, add_special_tokens=False)
        batch_encoding = tokenizer_r.batch_encode_plus([text] * batch_size, add_special_tokens=False)
        num_tokens = len(encoding["input_ids"])

        last_word_index = len(words) - 1
        last_token_index = num_tokens - 1
        last_batch_index = batch_size - 1
        last_char_index = len(text) - 1

        # words, tokens: one word id per token, spanning 0..last_word_index
        for enc, sample in ((encoding, 0), (batch_encoding, last_batch_index)):
            word_ids = enc.words(sample)
            self.assertEqual(len(word_ids), num_tokens)
            self.assertEqual(max(word_ids), last_word_index)
            self.assertEqual(min(word_ids), 0)
        self.assertEqual(len(encoding.tokens(0)), num_tokens)

        def check_mapping(method, last_in, last_out, span):
            # span=True: results expose .start/.end; the first probe checks .start, the last checks .end
            first_field, last_field = ("start", "end") if span else (None, None)
            on_single = getattr(encoding, method)
            on_batch = getattr(batch_encoding, method)
            probes = (
                (on_single, (0,), first_field, 0),
                (on_single, (0, 0), first_field, 0),
                (on_single, (last_in,), last_field, last_out),
                (on_single, (0, last_in), last_field, last_out),
                (on_batch, (1, 0), first_field, 0),
                (on_batch, (0, last_in), last_field, last_out),
                (on_batch, (last_batch_index, last_in), last_field, last_out),
            )
            for fn, args, field, expected in probes:
                result = fn(*args)
                self.assertEqual(getattr(result, field) if field else result, expected)

        check_mapping("token_to_word", last_token_index, last_word_index, span=False)
        check_mapping("word_to_tokens", last_word_index, last_token_index + 1, span=True)
        check_mapping("token_to_chars", last_token_index, last_char_index + 1, span=True)
        check_mapping("char_to_token", last_char_index, last_token_index, span=False)
        check_mapping("char_to_word", last_char_index, last_word_index, span=False)
        check_mapping("word_to_chars", last_word_index, last_char_index + 1, span=True)
Funtowicz Morgan's avatar
Funtowicz Morgan committed
183

184
    def assert_tokenization_python_rust_equals(self, tokenizer_r, tokenizer_p):
        """Check the Rust and Python tokenizers encode the fixture text identically.

        Compares ``input_ids``, ``token_type_ids`` and ``attention_mask`` (whichever
        the Python-side encoding returns) for:
        - a single sequence;
        - a sequence pair;
        - truncation to 512 tokens;
        - truncation with stride 3 and overflowing tokens requested.
        """
        compared_keys = ("input_ids", "token_type_ids", "attention_mask")

        def shared_keys(encoding):
            # Compared keys present in ``encoding``, in that encoding's own key order
            return [key for key in encoding.keys() if key in compared_keys]

        # Ensure basic input match
        input_p = tokenizer_p.encode_plus(self._data)
        input_r = tokenizer_r.encode_plus(self._data)
        for key in shared_keys(input_p):
            self.assertSequenceEqual(input_p[key], input_r[key])

        # Ensure pair input match; the keys come from the pair encoding itself
        input_pairs_p = tokenizer_p.encode_plus(self._data, self._data)
        input_pairs_r = tokenizer_r.encode_plus(self._data, self._data)
        for key in shared_keys(input_pairs_p):
            self.assertSequenceEqual(input_pairs_p[key], input_pairs_r[key])

        # Ensure truncation match
        input_p = tokenizer_p.encode_plus(self._data, max_length=512, truncation=True)
        input_r = tokenizer_r.encode_plus(self._data, max_length=512, truncation=True)
        for key in shared_keys(input_p):
            self.assertSequenceEqual(input_p[key], input_r[key])

        # Ensure truncation with stride match
        input_p = tokenizer_p.encode_plus(
            self._data, max_length=512, truncation=True, stride=3, return_overflowing_tokens=True
        )
        input_r = tokenizer_r.encode_plus(
            self._data, max_length=512, truncation=True, stride=3, return_overflowing_tokens=True
        )
        # The Rust values are indexed with [0]; presumably the fast tokenizer returns one entry
        # per overflow window here, and only the first window is compared — confirm
        for key in shared_keys(input_p):
            self.assertSequenceEqual(input_p[key], input_r[key][0])
Funtowicz Morgan's avatar
Funtowicz Morgan committed
215
216
217
218
219
220
221
222
223
224
225
226
227
228

    def assert_num_special_tokens_to_add_equal(self, tokenizer_r, tokenizer_p):
        """Both tokenizers must add the same number of special tokens, for single inputs and for pairs."""
        for is_pair in (False, True):
            self.assertEqual(
                tokenizer_r.num_special_tokens_to_add(is_pair), tokenizer_p.num_special_tokens_to_add(is_pair)
            )

    def assert_max_length_equal(self, tokenizer_r, tokenizer_p):
        """Both tokenizers must report the same maximum length for single sequences and for pairs."""
        for attribute in ("max_len_single_sentence", "max_len_sentences_pair"):
            self.assertEqual(getattr(tokenizer_r, attribute), getattr(tokenizer_p, attribute))

    def assert_special_tokens_map_equal(self, tokenizer_r, tokenizer_p):
        """Both tokenizers must expose the same special tokens map (attribute name -> token)."""
        # assertDictEqual accepts and rejects exactly the same maps as comparing the .items() views.
        # Unlike assertSequenceEqual on dict_items (which cannot be indexed), it reports a readable
        # key-by-key diff on mismatch.
        self.assertDictEqual(tokenizer_p.special_tokens_map, tokenizer_r.special_tokens_map)

233
234
235
236
237
238
239
240
    def assert_add_tokens(self, tokenizer_r):
        """Check the counts returned by add_tokens/add_special_tokens and the resulting tokenizer length.

        This mutates ``tokenizer_r``: 3 regular and 5 special tokens are added in total.
        """
        base_size = tokenizer_r.vocab_size

        # Regular tokens: the empty string is expected to add nothing
        for new_tokens, expected_added in (("", 0), ("testoken", 1), (["testoken1", "testtoken2"], 2)):
            self.assertEqual(tokenizer_r.add_tokens(new_tokens), expected_added)
        self.assertEqual(len(tokenizer_r), base_size + 3)

        # Special tokens
        self.assertEqual(tokenizer_r.add_special_tokens({}), 0)
        self.assertEqual(tokenizer_r.add_special_tokens({"bos_token": "[BOS]", "eos_token": "[EOS]"}), 2)
        # A bare string for additional_special_tokens is expected to be rejected with AssertionError
        self.assertRaises(
            AssertionError, tokenizer_r.add_special_tokens, {"additional_special_tokens": "<testtoken1>"}
        )
        for extra_tokens, expected_added in ((["<testtoken2>"], 1), (["<testtoken3>", "<testtoken4>"], 2)):
            self.assertEqual(
                tokenizer_r.add_special_tokens({"additional_special_tokens": extra_tokens}), expected_added
            )
        self.assertEqual(len(tokenizer_r), base_size + 8)
250

Funtowicz Morgan's avatar
Funtowicz Morgan committed
251
    def assert_offsets_mapping(self, tokenizer_r):
        """Check offset mappings and special-token masks for a single sequence and for a pair.

        For each input, there must be one offset per input id. The special-tokens mask
        must flag exactly as many tokens as ``num_special_tokens_to_add`` reports for
        that kind of input.
        """
        text = "Wonderful no inspiration example with subtoken"
        pair = "Along with an awesome pair"

        # (positional inputs to encode_plus, whether they form a pair)
        for inputs, is_pair in (((text,), False), ((text, pair), True)):
            encoded = tokenizer_r.encode_plus(
                *inputs, return_special_tokens_mask=True, return_offsets_mapping=True, add_special_tokens=True
            )
            expected_special = tokenizer_r.num_special_tokens_to_add(is_pair)

            # Same number of tokens and offsets
            self.assertEqual(len(encoded["offset_mapping"]), len(encoded["input_ids"]))
            # Only the added special tokens are flagged in the mask
            self.assertEqual(sum(encoded["special_tokens_mask"]), expected_special)

    def assert_batch_encode_dynamic_overflowing(self, tokenizer: PreTrainedTokenizer):
        """
        Check that overflowing encodings are returned as rectangular (2-D) tensors.

        When batch_encode is called with multiple sequences, each sequence can
        produce a different number of overflowing encodings:
        [
          Sequence 1: [Encoding 1, Encoding 2],
          Sequence 2: [Encoding 1],
          Sequence 3: [Encoding 1, Encoding 2, ... Encoding N]
        ]
        These need to be padded so that they can be represented as a single tensor.
        The check is skipped when the tokenizer has no usable padding token.
        """
        # PyTorch tensors when torch is available, TensorFlow tensors otherwise.
        # NOTE(review): with neither framework installed this still requests "tf" — confirm callers guard it.
        returned_tensor = "pt" if is_torch_available() else "tf"

        # Nothing can be padded without a valid pad token
        if not (tokenizer.pad_token and tokenizer.pad_token_id >= 0):
            return

        sentence = "HuggingFace is solving NLP one commit at a time"

        def check_rectangular(tokens, width=None):
            # Every output except overflow_to_sample_mapping must be 2-D (and `width` wide when given)
            for key in tokens.keys():
                if "overflow_to_sample_mapping" in key:
                    continue
                self.assertEqual(len(tokens[key].shape), 2)
                if width is not None:
                    self.assertEqual(tokens[key].shape[-1], width)

        # Single sequence through encode_plus: only the rank is checked
        check_rectangular(
            tokenizer.encode_plus(
                sentence,
                max_length=6,
                padding=True,
                truncation=True,
                return_tensors=returned_tensor,
                return_overflowing_tokens=True,
            )
        )

        # Mono sample, then multi sample, through batch_encode_plus: rank and width are checked
        for batch in ([sentence], [sentence, "Very tiny input"]):
            check_rectangular(
                tokenizer.batch_encode_plus(
                    batch,
                    max_length=6,
                    padding=True,
                    truncation="only_first",
                    return_tensors=returned_tensor,
                    return_overflowing_tokens=True,
                ),
                width=6,
            )

337
338
339
340
341
342
    def assert_pretokenized_inputs(self, tokenizer_r, tokenizer_p):
        """Check the Rust and Python tokenizers agree on pre-split (word list) inputs.

        Covers ``encode``, ``encode_plus`` and ``batch_encode_plus`` with
        ``is_split_into_words=True``, for single sequences and for sequence pairs.
        Every field returned by the Python tokenizer must equal the Rust one.
        """
        # Input string
        pretokenized_input_simple = "This is a sample input".split()
        pretokenized_input_pair = "This is a sample pair".split()

        # Options shared by encode_plus and batch_encode_plus (the two option sets are identical)
        encode_kwargs = {
            "is_split_into_words": True,
            "return_token_type_ids": True,
            "return_attention_mask": True,
            "return_overflowing_tokens": False,
            "return_special_tokens_mask": True,
            "return_offsets_mapping": False,  # Not implemented in python tokenizers
        }

        def assert_same_outputs(output_r, output_p):
            # Every field produced by the Python tokenizer must match the Rust one
            for key in output_p.keys():
                self.assertEqual(output_p[key], output_r[key])

        # Test encode for pretokenized inputs
        output_r = tokenizer_r.encode(pretokenized_input_simple, is_split_into_words=True)
        output_p = tokenizer_p.encode(pretokenized_input_simple, is_split_into_words=True)
        self.assertEqual(output_p, output_r)

        # Test encode_plus for pretokenized inputs
        assert_same_outputs(
            tokenizer_r.encode_plus(pretokenized_input_simple, **encode_kwargs),
            tokenizer_p.encode_plus(pretokenized_input_simple, **encode_kwargs),
        )

        # Test batch_encode_plus for pretokenized inputs
        input_batch = ([pretokenized_input_simple] * 2) + [pretokenized_input_simple + pretokenized_input_pair]
        assert_same_outputs(
            tokenizer_r.batch_encode_plus(input_batch, **encode_kwargs),
            tokenizer_p.batch_encode_plus(input_batch, **encode_kwargs),
        )

        # Test encode for pretokenized inputs pairs
        output_r = tokenizer_r.encode(pretokenized_input_simple, pretokenized_input_pair, is_split_into_words=True)
        output_p = tokenizer_p.encode(pretokenized_input_simple, pretokenized_input_pair, is_split_into_words=True)
        self.assertEqual(output_p, output_r)

        # Test encode_plus for pretokenized input pairs
        assert_same_outputs(
            tokenizer_r.encode_plus(pretokenized_input_simple, pretokenized_input_pair, **encode_kwargs),
            tokenizer_p.encode_plus(pretokenized_input_simple, pretokenized_input_pair, **encode_kwargs),
        )

        # Test batch_encode_plus on a larger mixed batch of word lists.
        # NOTE(review): every element below is a flat list of words, so this batch presumably holds
        # six single sequences rather than (first, second) pairs; real pairs would need tuples — TODO confirm.
        input_batch_pair = ([pretokenized_input_simple, pretokenized_input_pair] * 2) + [
            pretokenized_input_simple + pretokenized_input_pair,
            pretokenized_input_pair,
        ]
        assert_same_outputs(
            tokenizer_r.batch_encode_plus(input_batch_pair, **encode_kwargs),
            tokenizer_p.batch_encode_plus(input_batch_pair, **encode_kwargs),
        )

397
398
399
400
401
402
403
404
405
406
407
408
409
410
    def assert_create_token_type_ids(self, tokenizer_r, tokenizer_p):
        """Compare ``create_token_type_ids_from_sequences`` for a single id list and for a pair."""
        first_ids = [1, 2, 3]
        second_ids = [1, 2, 3]

        for sequences in ((first_ids,), (first_ids, second_ids)):
            from_rust = tokenizer_r.create_token_type_ids_from_sequences(*sequences)
            from_python = tokenizer_p.create_token_type_ids_from_sequences(*sequences)
            self.assertEqual(from_python, from_rust)

411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
    def assert_build_inputs_with_special_tokens(self, tokenizer_r, tokenizer_p):
        """Compare ``build_inputs_with_special_tokens`` on token strings and then on token ids.

        The inputs come from the Python tokenizer (``tokenize`` first, then ``encode``).
        Each kind is checked for a single sequence and for a pair.
        """
        for make_inputs in (tokenizer_p.tokenize, tokenizer_p.encode):
            first = make_inputs("This is a sample input")
            second = make_inputs("This is a sample pair")

            for sequences in ((first,), (first, second)):
                built_r = tokenizer_r.build_inputs_with_special_tokens(*sequences)
                built_p = tokenizer_p.build_inputs_with_special_tokens(*sequences)
                self.assertEqual(built_p, built_r)

Funtowicz Morgan's avatar
Funtowicz Morgan committed
440
441
    def assert_padding(self, tokenizer_r, tokenizer_p, max_length=15):
        def assert_padded_input_match(input_r: list, input_p: list, max_length: int):
442

Funtowicz Morgan's avatar
Funtowicz Morgan committed
443
            # Ensure we match max_length
444
445
            self.assertEqual(len(input_r), max_length)
            self.assertEqual(len(input_p), max_length)
446

Funtowicz Morgan's avatar
Funtowicz Morgan committed
447
448
449
450
            # Ensure the number of padded tokens is the same
            padded_tokens_r = list(takewhile(lambda i: i == tokenizer_r.pad_token_id, reversed(input_r)))
            padded_tokens_p = list(takewhile(lambda i: i == tokenizer_p.pad_token_id, reversed(input_p)))
            self.assertSequenceEqual(padded_tokens_r, padded_tokens_p)
451

452
        def assert_batch_padded_input_match(input_r: dict, input_p: dict, max_length: int):
Funtowicz Morgan's avatar
Funtowicz Morgan committed
453
            for i_r in input_r.values():
454
455
456
457
458
459
                self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), max_length), self.assertEqual(
                    len(i_r[1]), max_length
                )
                self.assertEqual(len(i_r), 2), self.assertEqual(len(i_r[0]), max_length), self.assertEqual(
                    len(i_r[1]), max_length
                )
460

Funtowicz Morgan's avatar
Funtowicz Morgan committed
461
462
            for i_r, i_p in zip(input_r["input_ids"], input_p["input_ids"]):
                assert_padded_input_match(i_r, i_p, max_length)
463

Funtowicz Morgan's avatar
Funtowicz Morgan committed
464
465
            for i_r, i_p in zip(input_r["attention_mask"], input_p["attention_mask"]):
                self.assertSequenceEqual(i_r, i_p)
466

467
        # Encode - Simple input
Funtowicz Morgan's avatar
Funtowicz Morgan committed
468
469
470
        input_r = tokenizer_r.encode("This is a simple input", max_length=max_length, pad_to_max_length=True)
        input_p = tokenizer_p.encode("This is a simple input", max_length=max_length, pad_to_max_length=True)
        assert_padded_input_match(input_r, input_p, max_length)
471
472
473
        input_r = tokenizer_r.encode("This is a simple input", max_length=max_length, padding="max_length")
        input_p = tokenizer_p.encode("This is a simple input", max_length=max_length, padding="max_length")
        assert_padded_input_match(input_r, input_p, max_length)
474

475
476
477
478
479
        input_r = tokenizer_r.encode("This is a simple input", padding="longest")
        input_p = tokenizer_p.encode("This is a simple input", padding=True)
        assert_padded_input_match(input_r, input_p, len(input_r))

        # Encode - Pair input
Funtowicz Morgan's avatar
Funtowicz Morgan committed
480
481
482
483
484
485
486
        input_r = tokenizer_r.encode(
            "This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
        )
        input_p = tokenizer_p.encode(
            "This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
        )
        assert_padded_input_match(input_r, input_p, max_length)
487
488
489
490
491
492
493
494
495
496
        input_r = tokenizer_r.encode(
            "This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
        )
        input_p = tokenizer_p.encode(
            "This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
        )
        assert_padded_input_match(input_r, input_p, max_length)
        input_r = tokenizer_r.encode("This is a simple input", "This is a pair", padding=True)
        input_p = tokenizer_p.encode("This is a simple input", "This is a pair", padding="longest")
        assert_padded_input_match(input_r, input_p, len(input_r))
497

498
        # Encode_plus - Simple input
Funtowicz Morgan's avatar
Funtowicz Morgan committed
499
500
501
502
        input_r = tokenizer_r.encode_plus("This is a simple input", max_length=max_length, pad_to_max_length=True)
        input_p = tokenizer_p.encode_plus("This is a simple input", max_length=max_length, pad_to_max_length=True)
        assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
        self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
503
504
505
506
        input_r = tokenizer_r.encode_plus("This is a simple input", max_length=max_length, padding="max_length")
        input_p = tokenizer_p.encode_plus("This is a simple input", max_length=max_length, padding="max_length")
        assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
        self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
507

508
509
510
511
512
513
514
        input_r = tokenizer_r.encode_plus("This is a simple input", padding="longest")
        input_p = tokenizer_p.encode_plus("This is a simple input", padding=True)
        assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]))

        self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])

        # Encode_plus - Pair input
Funtowicz Morgan's avatar
Funtowicz Morgan committed
515
516
517
518
519
520
521
522
        input_r = tokenizer_r.encode_plus(
            "This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
        )
        input_p = tokenizer_p.encode_plus(
            "This is a simple input", "This is a pair", max_length=max_length, pad_to_max_length=True
        )
        assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
        self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
523
524
525
526
527
528
529
530
531
532
533
534
        input_r = tokenizer_r.encode_plus(
            "This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
        )
        input_p = tokenizer_p.encode_plus(
            "This is a simple input", "This is a pair", max_length=max_length, padding="max_length"
        )
        assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)
        self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
        input_r = tokenizer_r.encode_plus("This is a simple input", "This is a pair", padding="longest")
        input_p = tokenizer_p.encode_plus("This is a simple input", "This is a pair", padding=True)
        assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]))
        self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
535

536
        # Batch_encode_plus - Simple input
Funtowicz Morgan's avatar
Funtowicz Morgan committed
537
538
539
540
541
542
        input_r = tokenizer_r.batch_encode_plus(
            ["This is a simple input 1", "This is a simple input 2"], max_length=max_length, pad_to_max_length=True
        )
        input_p = tokenizer_p.batch_encode_plus(
            ["This is a simple input 1", "This is a simple input 2"], max_length=max_length, pad_to_max_length=True
        )
543
544
545
        assert_batch_padded_input_match(input_r, input_p, max_length)

        input_r = tokenizer_r.batch_encode_plus(
Lysandre's avatar
Lysandre committed
546
547
548
            ["This is a simple input 1", "This is a simple input 2"],
            max_length=max_length,
            padding="max_length",
549
550
        )
        input_p = tokenizer_p.batch_encode_plus(
Lysandre's avatar
Lysandre committed
551
552
553
            ["This is a simple input 1", "This is a simple input 2"],
            max_length=max_length,
            padding="max_length",
554
555
556
557
        )
        assert_batch_padded_input_match(input_r, input_p, max_length)

        input_r = tokenizer_r.batch_encode_plus(
Lysandre's avatar
Lysandre committed
558
559
560
            ["This is a simple input 1", "This is a simple input 2"],
            max_length=max_length,
            padding="longest",
561
562
        )
        input_p = tokenizer_p.batch_encode_plus(
Lysandre's avatar
Lysandre committed
563
564
565
            ["This is a simple input 1", "This is a simple input 2"],
            max_length=max_length,
            padding=True,
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
        )
        assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))

        input_r = tokenizer_r.batch_encode_plus(
            ["This is a simple input 1", "This is a simple input 2"], padding="longest"
        )
        input_p = tokenizer_p.batch_encode_plus(["This is a simple input 1", "This is a simple input 2"], padding=True)
        assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))

        # Batch_encode_plus - Pair input
        input_r = tokenizer_r.batch_encode_plus(
            [
                ("This is a simple input 1", "This is a simple input 2"),
                ("This is a simple pair 1", "This is a simple pair 2"),
            ],
            max_length=max_length,
            truncation=True,
            padding="max_length",
        )
        input_p = tokenizer_p.batch_encode_plus(
            [
                ("This is a simple input 1", "This is a simple input 2"),
                ("This is a simple pair 1", "This is a simple pair 2"),
            ],
            max_length=max_length,
            truncation=True,
            padding="max_length",
        )
        assert_batch_padded_input_match(input_r, input_p, max_length)
595

Funtowicz Morgan's avatar
Funtowicz Morgan committed
596
597
598
599
600
        input_r = tokenizer_r.batch_encode_plus(
            [
                ("This is a simple input 1", "This is a simple input 2"),
                ("This is a simple pair 1", "This is a simple pair 2"),
            ],
601
            padding=True,
Funtowicz Morgan's avatar
Funtowicz Morgan committed
602
603
604
605
606
607
        )
        input_p = tokenizer_p.batch_encode_plus(
            [
                ("This is a simple input 1", "This is a simple input 2"),
                ("This is a simple pair 1", "This is a simple pair 2"),
            ],
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
            padding="longest",
        )
        assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))

        # Using pad on single examples after tokenization
        input_r = tokenizer_r.encode_plus("This is a input 1")
        input_r = tokenizer_r.pad(input_r)

        input_p = tokenizer_r.encode_plus("This is a input 1")
        input_p = tokenizer_r.pad(input_p)

        assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]))

        # Using pad on single examples after tokenization
        input_r = tokenizer_r.encode_plus("This is a input 1")
        input_r = tokenizer_r.pad(input_r, max_length=max_length, padding="max_length")

        input_p = tokenizer_r.encode_plus("This is a input 1")
        input_p = tokenizer_r.pad(input_p, max_length=max_length, padding="max_length")

        assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length)

        # Using pad after tokenization
        input_r = tokenizer_r.batch_encode_plus(
            ["This is a input 1", "This is a much longer input whilch should be padded"]
Funtowicz Morgan's avatar
Funtowicz Morgan committed
633
        )
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
        input_r = tokenizer_r.pad(input_r)

        input_p = tokenizer_r.batch_encode_plus(
            ["This is a input 1", "This is a much longer input whilch should be padded"]
        )
        input_p = tokenizer_r.pad(input_p)

        assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))

        # Using pad after tokenization
        input_r = tokenizer_r.batch_encode_plus(
            ["This is a input 1", "This is a much longer input whilch should be padded"]
        )
        input_r = tokenizer_r.pad(input_r, max_length=max_length, padding="max_length")

        input_p = tokenizer_r.batch_encode_plus(
            ["This is a input 1", "This is a much longer input whilch should be padded"]
        )
        input_p = tokenizer_r.pad(input_p, max_length=max_length, padding="max_length")

        assert_batch_padded_input_match(input_r, input_p, max_length)
655

Funtowicz Morgan's avatar
Funtowicz Morgan committed
656
657
658
    def assert_save_pretrained(self, tokenizer_r, tokenizer_p):
        """Check that the Rust and Python tokenizers save and reload consistently.

        Both tokenizers save their vocabulary into the same scratch directory and the
        returned file paths must match. Both are then reloaded from that directory, and
        every special-token attribute of the reloaded Python tokenizer must also exist on
        the reloaded Rust tokenizer.

        Args:
            tokenizer_r: the fast (Rust-backed) tokenizer under test.
            tokenizer_p: the slow (pure Python) reference tokenizer.
        """
        import tempfile

        # Use a throw-away directory instead of "." so the test neither litters the
        # working directory nor overwrites files that happen to live there.
        with tempfile.TemporaryDirectory() as tmpdirname:
            # Checks it saves with the same files
            self.assertSequenceEqual(tokenizer_r.save_vocabulary(tmpdirname), tokenizer_p.save_vocabulary(tmpdirname))

            # Checks everything loads correctly in the same way.
            # Both tokenizers wrote to the same directory, so both reload from the same files.
            tokenizer_rp = tokenizer_r.from_pretrained(tmpdirname)
            tokenizer_pp = tokenizer_p.from_pretrained(tmpdirname)

            # Check special tokens are set accordingly on Rust and Python.
            # NOTE: only presence is checked. Equality of the token values and of their
            # "<key>_id" counterparts was disabled in the original version of this test.
            for key in tokenizer_pp.special_tokens_map:
                self.assertTrue(hasattr(tokenizer_rp, key))
668

Funtowicz Morgan's avatar
Funtowicz Morgan committed
669
670
671
672
673
674
675
676
    def assert_embeded_special_tokens(self, tokenizer_r, tokenizer_p):
        """Compare Rust and Python output on a sentence with a special token written inline.

        The sentence contains a literal ``<mask>``. The checks are:

        - every field returned by the Python tokenizer's ``encode_plus`` equals the Rust one;
        - ``token_type_ids`` sums to 0 for both tokenizers;
        - the ids of both tokenizers convert back to the same token strings.

        Args:
            tokenizer_r: the fast (Rust-backed) tokenizer under test.
            tokenizer_p: the slow (pure Python) reference tokenizer.
        """
        sentence = "A, <mask> AllenNLP sentence."
        encode_options = {
            "add_special_tokens": True,
            "return_attention_mask": False,
            "return_token_type_ids": True,
        }
        encoded_r = tokenizer_r.encode_plus(sentence, **encode_options)
        encoded_p = tokenizer_p.encode_plus(sentence, **encode_options)

        # The fields present in the Python output drive the comparison.
        for field in encoded_p.keys():
            self.assertEqual(encoded_r[field], encoded_p[field])

        # Single-sequence input: token_type_ids must sum to zero for both tokenizers.
        for encoded in (encoded_r, encoded_p):
            self.assertEqual(sum(encoded["token_type_ids"]), 0)

        self.assertSequenceEqual(
            tokenizer_r.convert_ids_to_tokens(encoded_r["input_ids"]),
            tokenizer_p.convert_ids_to_tokens(encoded_p["input_ids"]),
        )

    def assert_add_special_tokens(self, tokenizer_r):
        """Check that ``add_special_tokens=True`` adds exactly the advertised number of tokens.

        For the empty string and a single space, each encoding entry point of ``tokenizer_r``
        is called with and without special tokens: ``tokenize``, ``encode``, ``encode_plus``
        and ``batch_encode_plus``. Each output must grow by exactly
        ``num_special_tokens_to_add(pair=False)`` elements. Pair inputs are not exercised here.

        Args:
            tokenizer_r: the fast (Rust-backed) tokenizer under test.
        """
        expected_extra = tokenizer_r.num_special_tokens_to_add(pair=False)

        for text in ["", " "]:
            # tokenize() and encode() both return one flat sequence per call.
            for encode_fn in (tokenizer_r.tokenize, tokenizer_r.encode):
                without_special = encode_fn(text, add_special_tokens=False)
                with_special = encode_fn(text, add_special_tokens=True)
                self.assertEqual(len(without_special), len(with_special) - expected_extra)

            # encode_plus() returns a mapping; every field must grow by the same amount.
            without_special = tokenizer_r.encode_plus(text, add_special_tokens=False)
            with_special = tokenizer_r.encode_plus(text, add_special_tokens=True)
            for field in without_special.keys():
                self.assertEqual(len(without_special[field]), len(with_special[field]) - expected_extra)

            # batch_encode_plus() returns one sequence per batch entry for every field.
            without_special = tokenizer_r.batch_encode_plus([text, text], add_special_tokens=False)
            with_special = tokenizer_r.batch_encode_plus([text, text], add_special_tokens=True)
            for field in without_special.keys():
                for seq_without, seq_with in zip(without_special[field], with_special[field]):
                    self.assertEqual(len(seq_without), len(seq_with) - expected_extra)

718
719
720
721
722
723
    def assert_prepare_for_model(self, tokenizer_r, tokenizer_p):
        """Check that ``prepare_for_model`` gives identical output for both tokenizers.

        Each tokenizer encodes the same sentence with its own ``encode``. The resulting ids
        are then fed to that same tokenizer's ``prepare_for_model``, and the two results
        must be equal. The Python tokenizer is run first.

        Args:
            tokenizer_r: the fast (Rust-backed) tokenizer under test.
            tokenizer_p: the slow (pure Python) reference tokenizer.
        """
        string_sequence = "Asserting that both tokenizers are equal"
        python_output, rust_output = [
            tokenizer.prepare_for_model(tokenizer.encode(string_sequence)) for tokenizer in (tokenizer_p, tokenizer_r)
        ]
        self.assertEqual(python_output, rust_output)

Funtowicz Morgan's avatar
Funtowicz Morgan committed
724
725
726
727
728
729
730
731

class WordPieceFastTokenizerTest(CommonFastTokenizerTest):
    """
    Run the common Rust/Python tokenizer checks on the WordPiece-based tokenizers (BERT and
    DistilBERT), plus a WordPiece-specific check of offset mappings around accented
    characters and special tokens.
    """

    TOKENIZERS_CLASSES = frozenset(
        [
            Tokenizer("Bert", BertTokenizerFast, BertTokenizer, "vocab_file", filter_non_english, None),
            Tokenizer(
                "DistilBert", DistilBertTokenizerFast, DistilBertTokenizer, "vocab_file", filter_non_english, None
            ),
        ]
    )

    def fast_only(self, tokenizer_r):
        """Run the inherited fast-only checks, then the WordPiece offset-mapping check."""
        super().fast_only(tokenizer_r)
        self.assert_offsets_with_special_characters(tokenizer_r)

    def assert_offsets_with_special_characters(self, tokenizer_r):
        """Check tokens and character offsets for a sentence with an accent and a ``[MASK]``.

        The expected tokens depend on the tokenizer's ``do_lower_case`` init kwarg.
        - Cased: "naïve" is split into "na" / "##ï" / "##ve".
        - Uncased: the accent is stripped and "naive" is a single token.
        Special tokens added around the sequence map to the empty span ``(0, 0)``.

        Args:
            tokenizer_r: the fast (Rust-backed) tokenizer under test.
        """
        # The accented letter is written as an escape (U+00EF, "ï") so a mis-decoded
        # source file cannot silently corrupt the input or the expectations below.
        sentence = "A, na\u00efve [MASK] AllenNLP sentence."
        tokens = tokenizer_r.encode_plus(
            sentence,
            return_attention_mask=False,
            return_token_type_ids=False,
            return_offsets_mapping=True,
            add_special_tokens=True,
        )

        do_lower_case = tokenizer_r.init_kwargs.get("do_lower_case")
        expected_results = (
            [
                ((0, 0), "[CLS]"),
                ((0, 1), "A"),
                ((1, 2), ","),
                ((3, 5), "na"),
                ((5, 6), "##\u00ef"),
                ((6, 8), "##ve"),
                ((9, 15), "[MASK]"),
                ((16, 21), "Allen"),
                ((21, 23), "##NL"),
                ((23, 24), "##P"),
                ((25, 33), "sentence"),
                ((33, 34), "."),
                ((0, 0), "[SEP]"),
            ]
            if not do_lower_case
            else [
                ((0, 0), "[CLS]"),
                ((0, 1), "a"),
                ((1, 2), ","),
                ((3, 8), "naive"),
                ((9, 15), "[MASK]"),
                ((16, 21), "allen"),
                ((21, 23), "##nl"),
                ((23, 24), "##p"),
                ((25, 33), "sentence"),
                ((33, 34), "."),
                ((0, 0), "[SEP]"),
            ]
        )

        self.assertEqual([e[1] for e in expected_results], tokenizer_r.convert_ids_to_tokens(tokens["input_ids"]))
        self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"])
791
792


Funtowicz Morgan's avatar
Funtowicz Morgan committed
793
794
class RobertaFastTokenizerTest(CommonFastTokenizerTest):
    """Run the common Rust/Python tokenizer checks on RoBERTa, with a RoBERTa-specific
    check of how an inline ``<mask>`` token is encoded."""

    TOKENIZERS_CLASSES = frozenset(
        [
            Tokenizer(
                "Roberta",
                RobertaTokenizerFast,
                RobertaTokenizer,
                "vocab_file",
                filter_roberta_detectors,
                (("cls_token", "<s>"),),
            )
        ]
    )

    def assert_embeded_special_tokens(self, tokenizer_r, tokenizer_p):
        """Check the exact ids and tokens produced for a sentence containing ``<mask>``.

        Both tokenizers must produce the same hard-coded ids and tokens. Their
        ``token_type_ids`` sums must match, and so must their attention-mask densities.

        Args:
            tokenizer_r: the fast (Rust-backed) tokenizer under test.
            tokenizer_p: the slow (pure Python) reference tokenizer.
        """
        sentence = "A, <mask> AllenNLP sentence."
        tokens_r = tokenizer_r.encode_plus(sentence, add_special_tokens=True, return_token_type_ids=True)
        tokens_p = tokenizer_p.encode_plus(sentence, add_special_tokens=True, return_token_type_ids=True)

        # The space before <mask> used to be handled differently by the Rust and Python
        # implementations; both are now expected to produce exactly these ids.
        self.assertSequenceEqual(tokens_r["input_ids"], [0, 250, 6, 50264, 3823, 487, 21992, 3645, 4, 2])
        self.assertSequenceEqual(tokens_p["input_ids"], [0, 250, 6, 50264, 3823, 487, 21992, 3645, 4, 2])

        # token_type_ids should put 0 everywhere
        self.assertEqual(sum(tokens_r["token_type_ids"]), sum(tokens_p["token_type_ids"]))

        # attention_mask should put 1 everywhere, so sum over length should be 1
        self.assertEqual(
            sum(tokens_r["attention_mask"]) / len(tokens_r["attention_mask"]),
            sum(tokens_p["attention_mask"]) / len(tokens_p["attention_mask"]),
        )

        # "\u0120" is "Ġ", the byte-level BPE marker for a preceding space. It is written as
        # an escape so a mis-decoded source file cannot corrupt the expectation.
        expected_tokens = ["<s>", "A", ",", "<mask>", "\u0120Allen", "N", "LP", "\u0120sentence", ".", "</s>"]
        tokens_r = tokenizer_r.convert_ids_to_tokens(tokens_r["input_ids"])
        tokens_p = tokenizer_p.convert_ids_to_tokens(tokens_p["input_ids"])
        self.assertSequenceEqual(tokens_r, expected_tokens)
        self.assertSequenceEqual(tokens_p, expected_tokens)
829

830

Funtowicz Morgan's avatar
Funtowicz Morgan committed
831
832
class NoPaddingTokenFastTokenizerMatchingTest(CommonFastTokenizerTest):
    TOKENIZERS_CLASSES = [
833
834
        Tokenizer("OpenAI GPT", OpenAIGPTTokenizerFast, OpenAIGPTTokenizer, "vocab_file", None, None),
        Tokenizer("GPT2", GPT2TokenizerFast, GPT2Tokenizer, "vocab_file", None, [("add_prefix_space", True)]),
Funtowicz Morgan's avatar
Funtowicz Morgan committed
835
    ]
836

837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
    def fast_align_python(self, tokenizer_r, tokenizer_p, tok_case, pretrained_name):
        """Run the Rust-vs-Python alignment checks for tokenizers without a padding token.

        Args:
            tokenizer_r: the fast (Rust-backed) tokenizer under test.
            tokenizer_p: the slow (pure Python) reference tokenizer.
            tok_case: the ``Tokenizer`` record being tested. Its ``rust_cls`` and optional
                ``kwargs`` (pairs such as ``("add_prefix_space", True)``) are used to
                rebuild the fast tokenizer for the pretokenized-input check.
            pretrained_name: checkpoint name passed to ``rust_cls.from_pretrained``.
        """
        # The is_fast flag must tell the two implementations apart.
        self.assertFalse(tokenizer_p.is_fast)
        self.assertTrue(tokenizer_r.is_fast)

        # Rust and Python outputs must align on every shared check.
        alignment_checks = (
            self.assert_tokenization_python_rust_equals,
            self.assert_num_special_tokens_to_add_equal,
            self.assert_max_length_equal,
            self.assert_special_tokens_map_equal,
            self.assert_embeded_special_tokens,
            self.assert_padding,
        )
        for check in alignment_checks:
            check(tokenizer_r, tokenizer_p)

        # Pretokenized inputs need the tokenizer-specific kwargs from the test case
        # (e.g. add_prefix_space for GPT2), so the fast tokenizer is rebuilt with them.
        extra_kwargs = dict(tok_case.kwargs) if tok_case.kwargs is not None else {}
        rust_with_kwargs = tok_case.rust_cls.from_pretrained(pretrained_name, **extra_kwargs)
        self.assert_pretokenized_inputs(rust_with_kwargs, tokenizer_p)

Funtowicz Morgan's avatar
Funtowicz Morgan committed
857
858
859
860
861
862
863
864
865
    def assert_padding(self, tokenizer_r, tokenizer_p, max_length=15):
        # Simple input
        s = "This is a simple input"
        s2 = ["This is a simple input 1", "This is a simple input 2"]
        p = ("This is a simple input", "This is a pair")
        p2 = [
            ("This is a simple input 1", "This is a simple input 2"),
            ("This is a simple pair 1", "This is a simple pair 2"),
        ]
866

Funtowicz Morgan's avatar
Funtowicz Morgan committed
867
        # Simple input tests
868
        self.assertRaises(ValueError, tokenizer_r.encode, s, max_length=max_length, padding="max_length")
869

Funtowicz Morgan's avatar
Funtowicz Morgan committed
870
        # Simple input
871
        self.assertRaises(ValueError, tokenizer_r.encode_plus, s, max_length=max_length, padding="max_length")
872

Funtowicz Morgan's avatar
Funtowicz Morgan committed
873
        # Simple input
874
        self.assertRaises(
Lysandre's avatar
Lysandre committed
875
876
877
878
879
            ValueError,
            tokenizer_r.batch_encode_plus,
            s2,
            max_length=max_length,
            padding="max_length",
880
        )
881

Funtowicz Morgan's avatar
Funtowicz Morgan committed
882
        # Pair input
883
        self.assertRaises(ValueError, tokenizer_r.encode, p, max_length=max_length, padding="max_length")
884

Funtowicz Morgan's avatar
Funtowicz Morgan committed
885
        # Pair input
886
        self.assertRaises(ValueError, tokenizer_r.encode_plus, p, max_length=max_length, padding="max_length")
887

Funtowicz Morgan's avatar
Funtowicz Morgan committed
888
        # Pair input
889
        self.assertRaises(
Lysandre's avatar
Lysandre committed
890
891
892
893
894
            ValueError,
            tokenizer_r.batch_encode_plus,
            p2,
            max_length=max_length,
            padding="max_length",
895
        )