test_tokenization_t5.py 32 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2018 Google T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
16
import json
import os
17
import re
18
import tempfile
19
import unittest
thomwolf's avatar
thomwolf committed
20

Lysandre Debut's avatar
Lysandre Debut committed
21
from transformers import SPIECE_UNDERLINE, AddedToken, BatchEncoding, T5Tokenizer, T5TokenizerFast
22
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_seqio, require_tokenizers, slow
23
from transformers.utils import cached_property, is_tf_available, is_torch_available
thomwolf's avatar
thomwolf committed
24

Yih-Dar's avatar
Yih-Dar committed
25
from ...test_tokenization_common import TokenizerTesterMixin
thomwolf's avatar
thomwolf committed
26

Aymeric Augustin's avatar
Aymeric Augustin committed
27

28
# Path to the SentencePiece model fixture that the tests below pass to `T5Tokenizer` / `T5TokenizerFast`.
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
29

30
31
32
33
34
35
# Value passed as `return_tensors` in the tests: PyTorch when available, then TensorFlow, otherwise JAX.
FRAMEWORK = "pt" if is_torch_available() else ("tf" if is_tf_available() else "jax")
36

thomwolf's avatar
thomwolf committed
37

38
39
@require_sentencepiece
@require_tokenizers
40
class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
41
    from_pretrained_id = "google-t5/t5-small"
thomwolf's avatar
thomwolf committed
42
    tokenizer_class = T5Tokenizer
43
44
    rust_tokenizer_class = T5TokenizerFast
    test_rust_tokenizer = True
45
    test_sentencepiece = True
thomwolf's avatar
thomwolf committed
46
47

    def setUp(self):
        """Persist a tokenizer built from the SentencePiece fixture into ``self.tmpdirname``."""
        super().setUp()

        # `get_tokenizer` / `get_rust_tokenizer` reload the tokenizer from this directory.
        fixture_tokenizer = T5Tokenizer(SAMPLE_VOCAB)
        fixture_tokenizer.save_pretrained(self.tmpdirname)

54
55
56
57
58
59
60
61
62
63
64
65
66
    def test_convert_token_and_id(self):
        """Test ``_convert_token_to_id`` and ``_convert_id_to_token`` on the ``<s>`` / ``1`` pair."""
        expected_token, expected_id = "<s>", 1

        id_from_token = self.get_tokenizer()._convert_token_to_id(expected_token)
        self.assertEqual(id_from_token, expected_id)

        token_from_id = self.get_tokenizer()._convert_id_to_token(expected_id)
        self.assertEqual(token_from_id, expected_token)

    def test_get_vocab(self):
        """Check the first two and the last vocabulary keys, then the total number of keys."""
        keys = list(self.get_tokenizer().get_vocab())

        for index, expected_key in ((0, "<unk>"), (1, "<s>"), (1100, "<pad>")):
            self.assertEqual(keys[index], expected_key)
        self.assertEqual(len(keys), 1_101)

    def test_vocab_size(self):
        """``vocab_size`` is expected to be 1000 while ``len(tokenizer)`` is expected to be 1101."""
        base_size = self.get_tokenizer().vocab_size
        self.assertEqual(base_size, 1000)

        full_size = len(self.get_tokenizer())
        self.assertEqual(full_size, 1101)
73

thomwolf's avatar
thomwolf committed
74
    def test_full_tokenizer(self):
        """Tokenize two sentences with the fixture model and round-trip tokens -> ids -> tokens."""
        tokenizer = T5Tokenizer(SAMPLE_VOCAB)

        simple_tokens = tokenizer.tokenize("This is a test")
        self.assertListEqual(simple_tokens, ["▁This", "▁is", "▁a", "▁t", "est"])
        self.assertListEqual(tokenizer.convert_tokens_to_ids(simple_tokens), [285, 46, 10, 170, 382])

        expected_tokens = [
            SPIECE_UNDERLINE + "I",
            SPIECE_UNDERLINE + "was",
            SPIECE_UNDERLINE + "b",
            "or",
            "n",
            SPIECE_UNDERLINE + "in",
            SPIECE_UNDERLINE + "",
            "9",
            "2",
            "0",
            "0",
            "0",
            ",",
            SPIECE_UNDERLINE + "and",
            SPIECE_UNDERLINE + "this",
            SPIECE_UNDERLINE + "is",
            SPIECE_UNDERLINE + "f",
            "al",
            "s",
            "é",
            ".",
        ]
        tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
        self.assertListEqual(tokens, expected_tokens)

        ids = tokenizer.convert_tokens_to_ids(tokens)
        self.assertListEqual(ids, [8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 0, 4])

        # "9" and "é" get id 0, so converting the ids back yields "<unk>" at those two positions;
        # every other piece is expected to round-trip unchanged.
        expected_back_tokens = ["<unk>" if piece in ("9", "é") else piece for piece in expected_tokens]
        back_tokens = tokenizer.convert_ids_to_tokens(ids)
        self.assertListEqual(back_tokens, expected_back_tokens)
139

140
141
    @cached_property
    def t5_base_tokenizer(self):
        """Slow ``google-t5/t5-base`` tokenizer, exposed through ``cached_property``."""
        checkpoint = "google-t5/t5-base"
        return T5Tokenizer.from_pretrained(checkpoint)
143

144
145
    @cached_property
    def t5_base_tokenizer_fast(self):
        """Fast ``google-t5/t5-base`` tokenizer, exposed through ``cached_property``."""
        checkpoint = "google-t5/t5-base"
        return T5TokenizerFast.from_pretrained(checkpoint)
147
148

    def get_tokenizer(self, **kwargs) -> T5Tokenizer:
        """Load ``tokenizer_class`` from ``self.tmpdirname``, forwarding ``kwargs`` to ``from_pretrained``."""
        load = self.tokenizer_class.from_pretrained
        return load(self.tmpdirname, **kwargs)
150
151

    def get_rust_tokenizer(self, **kwargs) -> T5TokenizerFast:
        """Load ``rust_tokenizer_class`` from ``self.tmpdirname``, forwarding ``kwargs`` to ``from_pretrained``."""
        load = self.rust_tokenizer_class.from_pretrained
        return load(self.tmpdirname, **kwargs)
153
154
155

    def test_rust_and_python_full_tokenizers(self):
        """The slow and fast tokenizers must agree on tokens and on ids for the same sentence."""
        if not self.test_rust_tokenizer:
            self.skipTest(reason="test_rust_tokenizer is set to False")

        slow = self.get_tokenizer()
        fast = self.get_rust_tokenizer()
        sentence = "I was born in 92000, and this is falsé."

        # Token strings.
        slow_tokens = slow.tokenize(sentence)
        fast_tokens = fast.tokenize(sentence)
        self.assertListEqual(slow_tokens, fast_tokens)

        # Ids without special tokens.
        slow_ids = slow.encode(sentence, add_special_tokens=False)
        fast_ids = fast.encode(sentence, add_special_tokens=False)
        self.assertListEqual(slow_ids, fast_ids)

        # Ids with the default special-token handling; the fast tokenizer is reloaded first.
        fast = self.get_rust_tokenizer()
        slow_ids = slow.encode(sentence)
        fast_ids = fast.encode(sentence)
        self.assertListEqual(slow_ids, fast_ids)

176
177
178
179
180
181
    def test_eos_treatment(self):
        """Texts ending in a literal ``</s>`` must encode to the same ids as the texts without it."""
        tokenizer = self.t5_base_tokenizer
        plain_texts = ["hi", "I went to the gym", ""]

        with_eos = tokenizer([text + "</s>" for text in plain_texts])
        without_eos = tokenizer(plain_texts)
        self.assertListEqual(with_eos["input_ids"], without_eos["input_ids"])

182
    def test_prepare_batch(self):
        """Encode a padded batch as framework tensors; check the first row and both tensor shapes."""
        tokenizer = self.t5_base_tokenizer
        src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
        expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, tokenizer.eos_token_id]

        batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK)
        self.assertIsInstance(batch, BatchEncoding)

        # The "jax" case is read through `tolist()`, the other frameworks through `numpy()`.
        first_row = batch.input_ids.tolist()[0] if FRAMEWORK == "jax" else batch.input_ids.numpy()[0]
        result = list(first_row)
        self.assertListEqual(expected_src_tokens, result)

        for tensor in (batch.input_ids, batch.attention_mask):
            self.assertEqual((2, 9), tensor.shape)
198

199
    def test_empty_target_text(self):
        """Encoding source text only must return encoder inputs and no decoder entries."""
        tokenizer = self.t5_base_tokenizer
        src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
        batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK)

        # check if input_ids are returned and no decoder_input_ids
        for expected_key in ("input_ids", "attention_mask"):
            self.assertIn(expected_key, batch)
        for unexpected_key in ("decoder_input_ids", "decoder_attention_mask"):
            self.assertNotIn(unexpected_key, batch)

209
    def test_max_length(self):
        """Targets encoded with ``padding="max_length"`` must be exactly ``max_length`` ids wide."""
        tokenizer = self.t5_base_tokenizer
        max_length = 32
        tgt_text = ["Summary of the text.", "Another summary."]

        targets = tokenizer(
            text_target=tgt_text,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors=FRAMEWORK,
        )
        self.assertEqual(max_length, targets["input_ids"].shape[1])
219
220

    def test_outputs_not_longer_than_maxlen(self):
        """A very long input is expected to keep its full length: the padded batch is 8001 ids wide."""
        tokenizer = self.t5_base_tokenizer
        short_text = "I am a small frog"
        long_text = short_text * 1000

        batch = tokenizer([long_text, short_text], padding=True, truncation=True, return_tensors=FRAMEWORK)
        self.assertIsInstance(batch, BatchEncoding)

        # Since T5 does NOT have a max input length, the long input is expected to come back
        # at its full length even though `truncation=True` is passed.
        expected_shape = (2, 8001)
        self.assertEqual(batch.input_ids.shape, expected_shape)
231
232

    def test_eos_in_input(self):
        """Source and target texts that already end in ``</s>`` must yield ids ending in a single ``1``."""
        tokenizer = self.t5_base_tokenizer
        src_text = ["A long paragraph for summarization. </s>"]
        tgt_text = ["Summary of the text. </s>"]

        batch = tokenizer(src_text, text_target=tgt_text)

        expected_by_key = {
            "input_ids": [71, 307, 8986, 21, 4505, 1635, 1707, 5, 1],
            "labels": [20698, 13, 8, 1499, 5, 1],
        }
        for key, expected_ids in expected_by_key.items():
            self.assertEqual(expected_ids, batch[key][0])
243

244
245
246
247
248
249
250
251
252
253
254
255
256
257
    def test_token_type_ids(self):
        """Slow and fast tokenizers must return the same ``token_type_ids`` for a text pair."""
        first = ["A first paragraph for summarization."]
        second = ["A second paragraph for summarization."]
        call_kwargs = {"add_special_tokens": True, "return_token_type_ids": True}

        fast_token_type_ids = self.t5_base_tokenizer_fast(first, second, **call_kwargs).token_type_ids
        slow_token_type_ids = self.t5_base_tokenizer(first, second, **call_kwargs).token_type_ids

        self.assertEqual(slow_token_type_ids, fast_token_type_ids)
        self.assertEqual(len(slow_token_type_ids[0]), 18)

258
259
260
261
262
263
264
265
266
267
268
269
270
271
    def test_fast_and_slow_same_result(self):
        """Slow and fast tokenizers must encode and decode text with special tokens identically."""
        src_text = "<pad> Today is <unk> nice day </s>"
        tgt_ids = [0, 1960, 19, 2, 1245, 239, 1]
        tgt_text = "<pad> Today is<unk> nice day</s>"
        fast, slow = self.t5_base_tokenizer_fast, self.t5_base_tokenizer

        # Encoding: both tokenizers must produce the target ids.
        fast_ids = fast(src_text, add_special_tokens=False).input_ids
        slow_ids = slow(src_text, add_special_tokens=False).input_ids
        self.assertEqual(tgt_ids, fast_ids)
        self.assertEqual(tgt_ids, slow_ids)

        # Decoding: both tokenizers decode the same id list (the fast tokenizer's output).
        fast_text = fast.decode(fast_ids)
        slow_text = slow.decode(fast_ids)
        self.assertEqual(tgt_text, fast_text)
        self.assertEqual(tgt_text, slow_text)
Lysandre Debut's avatar
Lysandre Debut committed
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298

    def test_special_tokens_initialization(self):
        """Special tokens passed at load time must be encoded identically by every tokenizer variant.

        For each entry of ``self.tokenizers_list`` three tokenizers are loaded with the same
        ``additional_special_tokens``: the fast one, the fast one with ``from_slow=True`` and the slow
        one. Their encodings of a sentence containing ``<special>`` must match, and the id of
        ``<special>`` must appear in each of them.
        """
        text = "Hey this is a <special> token"

        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
                added_tokens = [f"<extra_id_{i}>" for i in range(100)] + [AddedToken("<special>", lstrip=True)]

                tokenizer_r = self.rust_tokenizer_class.from_pretrained(
                    pretrained_name, additional_special_tokens=added_tokens, **kwargs
                )
                tokenizer_cr = self.rust_tokenizer_class.from_pretrained(
                    pretrained_name, additional_special_tokens=added_tokens, **kwargs, from_slow=True
                )
                tokenizer_p = self.tokenizer_class.from_pretrained(
                    pretrained_name, additional_special_tokens=added_tokens, **kwargs
                )

                p_output = tokenizer_p.encode(text)
                r_output = tokenizer_r.encode(text)
                cr_output = tokenizer_cr.encode(text)

                special_token_id = tokenizer_r.encode("<special>", add_special_tokens=False)[0]

                self.assertEqual(p_output, r_output)
                self.assertEqual(cr_output, r_output)
                # `assertIn` reports both the id and the sequence on failure,
                # unlike `assertTrue(x in y)` which only says "False is not true".
                self.assertIn(special_token_id, p_output)
                self.assertIn(special_token_id, r_output)
                self.assertIn(special_token_id, cr_output)
299

300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
    def test_special_tokens_initialization_with_non_empty_additional_special_tokens(self):
        """``additional_special_tokens`` stored in the saved JSON files are loaded and can be overridden.

        For each enabled tokenizer flavour the tokenizer is saved to a temporary directory,
        ``special_tokens_map.json`` and ``tokenizer_config.json`` are rewritten with a new
        ``additional_special_tokens`` list, and the tokenizer is reloaded once without and once with an
        explicit ``additional_special_tokens`` argument.
        """
        candidates = []
        if self.test_slow_tokenizer:
            candidates.append((self.tokenizer_class, self.get_tokenizer()))
        if self.test_rust_tokenizer:
            candidates.append((self.rust_tokenizer_class, self.get_rust_tokenizer()))

        token_from_files = "an_additional_special_token"
        token_from_init = "a_new_additional_special_token"

        for tokenizer_class, tokenizer_utils in candidates:
            with tempfile.TemporaryDirectory() as tmp_dir:
                tokenizer_utils.save_pretrained(tmp_dir)

                added_tokens_extra_ids = [f"<extra_id_{i}>" for i in range(100)]

                # Overwrite `additional_special_tokens` in both configuration files.
                for file_name in ("special_tokens_map.json", "tokenizer_config.json"):
                    file_path = os.path.join(tmp_dir, file_name)
                    with open(file_path, encoding="utf-8") as json_file:
                        content = json.load(json_file)
                    content["additional_special_tokens"] = added_tokens_extra_ids + [token_from_files]
                    with open(file_path, "w", encoding="utf-8") as outfile:
                        json.dump(content, outfile)

                # the following checks allow us to verify that our test works as expected, i.e. that the tokenizer takes
                # into account the new value of additional_special_tokens given in the "tokenizer_config.json" and
                # "special_tokens_map.json" files
                tokenizer_without_change_in_init = tokenizer_class.from_pretrained(tmp_dir)
                self.assertIn(token_from_files, tokenizer_without_change_in_init.additional_special_tokens)
                # self.assertIn("an_additional_special_token",tokenizer_without_change_in_init.get_vocab()) # ByT5Tokenization no vocab
                round_trip = tokenizer_without_change_in_init.convert_ids_to_tokens(
                    tokenizer_without_change_in_init.convert_tokens_to_ids([token_from_files])
                )
                self.assertEqual([token_from_files], round_trip)

                # Now we test that we can change the value of additional_special_tokens in the from_pretrained
                new_added_tokens = added_tokens_extra_ids + [AddedToken(token_from_init, lstrip=True)]
                tokenizer = tokenizer_class.from_pretrained(tmp_dir, additional_special_tokens=new_added_tokens)

                self.assertIn(token_from_init, tokenizer.additional_special_tokens)
                round_trip = tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids([token_from_init]))
                self.assertEqual([token_from_init], round_trip)

364
    # overwritten from `test_tokenization_common` since T5 has no max length
365
366
    @slow
    def test_tokenizer_integration(self):
        """Compare ``google-t5/t5-base`` at a pinned revision against a stored padded encoding."""
        # Expected ids of the three sequences, without padding.
        # fmt: off
        unpadded_ids = [
            [31220, 7, 41, 14034, 801, 38, 3, 102, 63, 17, 127, 524, 18, 7031, 2032, 277, 11, 3, 102, 63, 17, 127, 524, 18, 2026, 17, 10761, 18, 7041, 61, 795, 879, 18, 19681, 4648, 7, 41, 12920, 382, 6, 350, 6383, 4949, 6, 2158, 12920, 382, 9, 6, 3, 4, 11160, 6, 2043, 17153, 279, 49, 17, 6, 3, 4, 434, 9688, 11439, 21, 6869, 10509, 17725, 41, 567, 9138, 61, 11, 6869, 10509, 11946, 41, 18207, 517, 61, 28, 147, 3538, 1220, 7140, 10761, 2250, 16, 910, 1220, 8024, 11, 1659, 1413, 32, 883, 2020, 344, 2215, 226, 6, 12901, 382, 127, 524, 11, 4738, 7, 127, 15390, 5, 1],
            [272, 24203, 19, 876, 12, 554, 18, 9719, 1659, 2647, 26352, 6497, 7, 45, 73, 9339, 400, 26, 1499, 57, 22801, 10760, 30, 321, 646, 11, 269, 2625, 16, 66, 7500, 5, 1],
            [37, 1704, 4216, 3, 20400, 4418, 7, 147, 8, 19743, 1782, 5, 1],
        ]
        # fmt: on

        # The stored encoding is rectangular: every row is right-padded with 0 up to the longest row,
        # and the attention mask is 1 on real ids and 0 on the padding.
        width = max(len(row) for row in unpadded_ids)
        expected_encoding = {
            "input_ids": [row + [0] * (width - len(row)) for row in unpadded_ids],
            "attention_mask": [[1] * len(row) + [0] * (width - len(row)) for row in unpadded_ids],
        }

        self.tokenizer_integration_test_util(
            expected_encoding=expected_encoding,
            model_name="google-t5/t5-base",
            revision="5a7ff2d8f5117c194c7e32ec1ccbf04642cca99b",
        )
374
375
376
377

    def test_get_sentinel_tokens(self):
        """``get_sentinel_tokens`` must return exactly the ten ``<extra_id_N>`` tokens (slow tokenizer)."""
        tokenizer = T5Tokenizer(SAMPLE_VOCAB, extra_ids=10)
        sentinel_tokens = tokenizer.get_sentinel_tokens()

        self.assertEqual(len(sentinel_tokens), 10)
        self.assertListEqual(sorted(sentinel_tokens), sorted(f"<extra_id_{i}>" for i in range(10)))
        # `assertTrue` on a non-empty list always passes; `all` makes every token count.
        self.assertTrue(all(re.search(r"<extra_id_\d+>", token) is not None for token in sentinel_tokens))
381
382
383

    def test_get_sentinel_token_ids(self):
        """With ``extra_ids=10`` the slow tokenizer's sentinel ids must be 1000..1009 in any order."""
        tokenizer = T5Tokenizer(SAMPLE_VOCAB, extra_ids=10)
        sentinel_ids = sorted(tokenizer.get_sentinel_token_ids())
        self.assertListEqual(sentinel_ids, list(range(1000, 1010)))
385
386
387
388

    def test_get_sentinel_tokens_for_fasttokenizer(self):
        """``get_sentinel_tokens`` must return exactly the ten ``<extra_id_N>`` tokens (fast tokenizer)."""
        tokenizer = T5TokenizerFast(SAMPLE_VOCAB, extra_ids=10)
        sentinel_tokens = tokenizer.get_sentinel_tokens()

        self.assertEqual(len(sentinel_tokens), 10)
        self.assertListEqual(sorted(sentinel_tokens), sorted(f"<extra_id_{i}>" for i in range(10)))
        # `assertTrue` on a non-empty list always passes; `all` makes every token count.
        self.assertTrue(all(re.search(r"<extra_id_\d+>", token) is not None for token in sentinel_tokens))
392
393
394

    def test_get_sentinel_token_ids_for_fasttokenizer(self):
        """With ``extra_ids=10`` the fast tokenizer's sentinel ids must be 1000..1009 in any order."""
        tokenizer = T5TokenizerFast(SAMPLE_VOCAB, extra_ids=10)
        sentinel_ids = sorted(tokenizer.get_sentinel_token_ids())
        self.assertListEqual(sentinel_ids, list(range(1000, 1010)))
396

397
    def test_some_edge_cases(self):
        """With ``legacy=False``: ``</s>`` stays one token and whitespace-only inputs give no tokens."""
        tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=False)

        # The raw sentencepiece model splits "</s>" into characters, the tokenizer keeps it whole.
        text = "</s>>"
        sp_tokens = tokenizer.sp_model.encode(text, out_type=str)
        self.assertEqual(sp_tokens, ["<", "/", "s", ">", ">"])
        tokens = tokenizer.tokenize(text)
        self.assertNotEqual(sp_tokens, tokens)
        self.assertEqual(tokens, ["</s>", ">"])

        # Pairs of (text given to `tokenize`, text given to `sp_model.encode`).
        # NOTE(review): the last pair tokenizes " ▁" but compares against the encoding of "▁";
        # kept exactly as written — confirm whether " ▁" was intended on both sides.
        blank_cases = (("", ""), (" ", " "), ("▁", "▁"), (" ▁", "▁"))
        for text, sp_text in blank_cases:
            tokens = tokenizer.tokenize(text)
            self.assertEqual(tokens, [])
            self.assertEqual(tokens, tokenizer.sp_model.encode(sp_text, out_type=str))

422
423
    def test_fast_slow_edge_cases(self):
        """Check spacing around an added token for slow and fast tokenizers.

        The first half adds ``<new_token_test_>`` with ``normalized=False`` to both tokenizers and
        expects identical tokens from both. The second half reloads the fast tokenizer, adds the
        token with ``normalized=True`` and checks the fast output alone.
        Each sub-test label carries the text actually tokenized and the ``normalized`` flag in use.
        """
        # We are testing spaces before and spaces after special tokens + space transformations
        slow_tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=False)
        fast_tokenizer = T5TokenizerFast.from_pretrained("google-t5/t5-base", legacy=False, from_slow=True)
        slow_tokenizer.add_tokens(AddedToken("<new_token_test_>", rstrip=False, lstrip=False, normalized=False))
        fast_tokenizer.add_tokens(AddedToken("<new_token_test_>", rstrip=False, lstrip=False, normalized=False))

        edge_case = "Hey!<new_token_test_>. How</s>Hey <new_token_test_>!"
        EXPECTED_SLOW = ["▁Hey", "!", "<new_token_test_>", ".", "▁How", "</s>", "He", "y", "<new_token_test_>", "!"]  # fmt: skip
        with self.subTest(f"slow {edge_case} normalized = False"):
            self.assertEqual(slow_tokenizer.tokenize(edge_case), EXPECTED_SLOW)
        with self.subTest(f"fast {edge_case} normalized = False"):
            self.assertEqual(fast_tokenizer.tokenize(edge_case), EXPECTED_SLOW)

        hard_case = "Hey! <new_token_test_>. How</s>   Hey   <new_token_test_>  !     .     "
        EXPECTED_SLOW = ["▁Hey", "!", "<new_token_test_>", ".", "▁How", "</s>", "▁Hey", "<new_token_test_>", "▁", "!", "▁", "."]  # fmt: skip
        # The labels name `hard_case` so a failure here is not confused with the `edge_case` sub-tests.
        with self.subTest(f"slow {hard_case} normalized = False"):
            self.assertEqual(slow_tokenizer.tokenize(hard_case), EXPECTED_SLOW)
        with self.subTest(f"fast {hard_case} normalized = False"):
            self.assertEqual(fast_tokenizer.tokenize(hard_case), EXPECTED_SLOW)

        fast_tokenizer = T5TokenizerFast.from_pretrained("google-t5/t5-base", legacy=False, from_slow=True)
        fast_tokenizer.add_tokens(AddedToken("<new_token_test_>", rstrip=False, lstrip=False, normalized=True))

        # `normalized=True` is the default normalization scheme when adding a token. Normalize -> don't strip the space.
        # the issue now is that our slow tokenizer should NOT strip the space if we want to simulate sentencepiece token addition.

        EXPECTED_FAST = ["▁Hey", "!", "<new_token_test_>", ".", "▁How", "</s>", "He", "y", "▁", "<new_token_test_>", "!"]  # fmt: skip
        with self.subTest(f"fast {edge_case} normalized = True"):
            self.assertEqual(fast_tokenizer.tokenize(edge_case), EXPECTED_FAST)

        EXPECTED_FAST = ['▁Hey', '!', '▁', '<new_token_test_>', '.', '▁How', '</s>', '▁Hey','▁', '<new_token_test_>', '▁', '!', '▁', '.']  # fmt: skip
        with self.subTest(f"fast {hard_case} normalized = True"):
            self.assertEqual(fast_tokenizer.tokenize(hard_case), EXPECTED_FAST)

457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
    def test_add_prefix_space(self):
        """Slow and fast tokenizers must agree with ``add_prefix_space`` both off and on."""
        pretrained_name = "google-t5/t5-base"
        inputs = "Hey how are you doing"

        # (add_prefix_space, extra kwargs for the fast tokenizer, expected ids, expected tokens)
        cases = (
            (
                False,
                {"from_slow": True},
                [3845, 63, 149, 33, 25, 692, 1],
                ["He", "y", "▁how", "▁are", "▁you", "▁doing"],
            ),
            (
                True,
                {},
                [9459, 149, 33, 25, 692, 1],
                ["▁Hey", "▁how", "▁are", "▁you", "▁doing"],
            ),
        )

        for add_prefix_space, fast_kwargs, expected_ids, expected_tokens in cases:
            slow_ = self.tokenizer_class.from_pretrained(
                pretrained_name, add_prefix_space=add_prefix_space, legacy=False
            )
            fast_ = self.rust_tokenizer_class.from_pretrained(
                pretrained_name, add_prefix_space=add_prefix_space, legacy=False, **fast_kwargs
            )

            self.assertEqual(slow_.encode(inputs), expected_ids)
            self.assertEqual(slow_.encode(inputs), fast_.encode(inputs))
            self.assertEqual(slow_.tokenize(inputs), expected_tokens)
            # Decoding without special tokens must give back the input text for both tokenizers.
            self.assertEqual(slow_.decode(expected_ids, skip_special_tokens=True), inputs)
            self.assertEqual(
                slow_.decode(expected_ids, skip_special_tokens=True),
                fast_.decode(expected_ids, skip_special_tokens=True),
            )

487
488
489
490
491
492
493
494
495
496

@require_sentencepiece
@require_tokenizers
class CommonSpmIntegrationTests(unittest.TestCase):
    """
    A class that regroups important test to make sure that we properly handle the special tokens.
    """

    @classmethod
    def setUpClass(cls):
        """Build the shared ``legacy=False`` tokenizer with ``<extra_id_0>`` as its only extra special token."""
        shared_tokenizer = T5Tokenizer(SAMPLE_VOCAB, extra_ids=0, legacy=False)
        sentinel = AddedToken("<extra_id_0>", rstrip=False, lstrip=False)
        # TODO ArthurZ the explicit `add_special_tokens` call is necessary as addedTokens / initialization
        # sucks. Trie is not correctly created, so the extra ids are split....
        shared_tokenizer.add_special_tokens({"additional_special_tokens": [sentinel]})
        cls.tokenizer = shared_tokenizer

    def test_add_dummy_prefix(self):
        """A leading ``▁`` is prepended to the output; whitespace-only inputs produce no tokens."""
        # make sure `'▁'` is prepended, and outputs match sp_model's
        # `sentencepiece.NormalizerSpec.add_dummy_prefix` attribute
        text = ". Hello"
        input_ids = self.tokenizer.encode(text, add_special_tokens=False)
        self.assertEqual(input_ids, [7, 4, 156, 86, 20])
        sp_encode = self.tokenizer.sp_model.encode(text)
        # The tokenizer output is the raw sentencepiece output with one extra leading id 7.
        self.assertEqual(input_ids, [7] + sp_encode)
        tokens = self.tokenizer.tokenize(text)
        self.assertEqual(tokens, ["▁", ".", "▁He", "ll", "o"])

        for blank in ("", " ", "▁"):
            tokens = self.tokenizer.tokenize(blank)
            self.assertEqual(tokens, [])
            self.assertEqual(tokens, self.tokenizer.sp_model.encode(blank, out_type=str))

527
528
529
530
531
532
    def test_remove_extra_whitespaces(self):
        """Runs of spaces (and ``▁``) must collapse, at the start and after a special token."""
        tok = self.tokenizer

        # make sure the extra spaces are eaten
        # sentencepiece.NormalizerSpec.remove_extra_whitespaces attribute
        padded_text = "       . Hello"
        input_ids = tok.encode(padded_text, add_special_tokens=False)
        self.assertEqual(input_ids, [7, 4, 156, 86, 20])
        sp_encode = tok.sp_model.encode(padded_text)
        self.assertEqual(input_ids, [7] + sp_encode)
        tokens = tok.tokenize(" . Hello")
        self.assertEqual(tokens, ["▁", ".", "▁He", "ll", "o"])

        # `'▁'` is also a whitespace
        self.assertEqual(tok.encode("▁He is not"), [156, 46, 44, 2])
        self.assertEqual(tok.tokenize("▁He is not"), ["▁He", "▁is", "▁not"])  # no extra space added

        # here t5x does not eat with lstrip, so there is an extra ▁He in the original one
        input_ids = tok.encode("▁He is not<extra_id_0>             ▁He")
        self.assertEqual(input_ids, [156, 46, 44, 1001, 156, 2])
        tokens = tok.tokenize("▁He is not<extra_id_0>              ▁He")
        self.assertEqual(tokens, ["▁He", "▁is", "▁not", "<extra_id_0>", "▁He"])  # spaces are eaten by spm

        # make sure that the output after the extra id is the same as if
        # extra_id was not there
        input_ids = tok.encode("▁He is not             ▁He")
        self.assertEqual(input_ids, [156, 46, 44, 156, 2])
        tokens = tok.tokenize("▁He is not              ▁He")
        self.assertEqual(tokens, ["▁He", "▁is", "▁not", "▁He"])  # spaces are eaten by spm even if not start

    def test_character_after_special_token(self):
        """Text directly following ``<extra_id_0>`` must be tokenized without an inserted ``▁``."""
        # Make sure that `tokenizer.tokenize` is similar to
        # adding the equivalent special token to the vocab
        cases = (
            ("Hey <extra_id_0>I", [156, 30, 1001, 100, 2], ["▁He", "y", "<extra_id_0>", "I"]),
            ("Hello, <extra_id_0>,", [156, 86, 20, 3, 1001, 3, 2], ["▁He", "ll", "o", ",", "<extra_id_0>", ","]),
        )
        for text, expected_ids, expected_tokens in cases:
            input_ids = self.tokenizer.encode(text)
            self.assertEqual(input_ids, expected_ids)
            tokens = self.tokenizer.tokenize(text)
            self.assertEqual(tokens, expected_tokens)

    def test_special_tokens_strip(self):
        """Check how spaces next to a special token are handled, with and without ``legacy=False``."""
        shared = self.tokenizer

        text = " <extra_id_0> ,"
        self.assertEqual(shared.encode(text), [1001, 7, 3, 2])
        # spaces are no longer eaten by rstrip and lstrip
        self.assertEqual(shared.tokenize(text), ["<extra_id_0>", "▁", ","])

        # test with a begin of word like `▁He`
        text = "No <extra_id_0> He"
        self.assertEqual(shared.encode(text), [284, 1001, 156, 2])
        # spaces are eaten by rstrip / lstrip, so this is expected. Don't strip otherwise you break
        self.assertEqual(shared.tokenize(text), ["▁No", "<extra_id_0>", "▁He"])

        # Make sure this does not happen if we don't strip
        tokenizer = T5Tokenizer(SAMPLE_VOCAB, extra_ids=0)
        tokenizer.add_special_tokens({"bos_token": AddedToken("<bos>")})
        text = "No <bos> He"
        input_ids = tokenizer.encode(text)
        self.assertEqual(input_ids, [284, 1001, 156, 2])
        tokens = tokenizer.tokenize(text)
        # the first `' '` after `'No'` is eaten by spm:
        self.assertEqual(tokenizer.sp_model.encode("No         ", out_type=str), ["▁No"])
        self.assertEqual(tokens, ["▁No", "<bos>", "▁He"])

    @require_seqio
    @unittest.skipIf(
        os.getenv("RUN_TOKENIZER_INTEGRATION", "0") == "0",
        "RUN_TOKENIZER_INTEGRATION=1 to run tokenizer integration tests",
    )
    def test_integration_seqio(self):
        """Compare Hugging Face ids with seqio's ``SentencePieceVocabulary`` ids on XNLI premises.

        Runs only when ``RUN_TOKENIZER_INTEGRATION`` is set to something other than ``"0"``.
        The same comparison is made for umt5 and for T5: first on a few hand-written texts that
        contain ``<extra_id_0>``, then on every text of the dataset's ``premise`` column.
        """
        import tqdm
        from datasets import load_dataset
        from seqio import SentencePieceVocabulary

        ds = load_dataset("facebook/xnli", "all_languages", split="train+test+validation")

        # TODO @ArthurZucker fix the 3 commented tests with #23909
        input_texts = [
            "Bonjour <extra_id_0>.",
            # "Bonjour<extra_id_0>.",  # this will fail. In T5 the special token has to be at the end.
            # because in T5 they add `_<extra_id_0>` to the vocab, not `<extra_id_0>`.
            "                   Hey <extra_id_0>I love you",
            # "Hey <extra_id_0> I love you", # this will fail, we strip left, to _I vs I
            # "Hey <extra_id_0>▁He", # this will fail for the same reason, we replace `_` then strip
        ]

        def assert_same_ids(hf_tokenizer, t5x_tokenizer):
            """Assert both tokenizers give the same ids for the sample texts and all premises."""

            def check(text):
                self.assertEqual(
                    hf_tokenizer.encode(text, add_special_tokens=False),
                    t5x_tokenizer.tokenizer.tokenize(text),
                    f"{text}",
                )

            for text in input_texts:
                check(text)
            for texts in tqdm.tqdm(ds["premise"]):
                for text in texts:
                    check(text)

        # Test with umt5
        assert_same_ids(
            T5Tokenizer.from_pretrained("google/umt5-small", legacy=False),
            SentencePieceVocabulary("gs://t5-data/vocabs/umt5.256000/sentencepiece.model", extra_ids=300),
        )

        # Test with T5
        assert_same_ids(
            T5Tokenizer.from_pretrained("google-t5/t5-small"),
            SentencePieceVocabulary("gs://t5-data/vocabs/cc_all.32000/sentencepiece.model", extra_ids=300),
        )