# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import tempfile
import unittest
import unittest.mock as mock
from pathlib import Path

from huggingface_hub import HfFolder, delete_repo
from huggingface_hub.file_download import http_get
from requests.exceptions import HTTPError

from transformers import (
    AlbertTokenizer,
    AutoTokenizer,
    BertTokenizer,
    BertTokenizerFast,
    GPT2TokenizerFast,
    is_tokenizers_available,
)
from transformers.testing_utils import TOKEN, USER, is_staging_test, require_tokenizers
from transformers.tokenization_utils import ExtensionsTrie, Trie


# Append the `utils` directory (sibling of this file's parent directory) to the import path.
# Presumably this is where the `test_module` package imported below lives — TODO confirm.
sys.path.append(str(Path(__file__).parent.parent / "utils"))

# Must come after the `sys.path` tweak above, hence the E402 (import not at top of file) waiver.
from test_module.custom_tokenization import CustomTokenizer  # noqa E402


# The fast custom tokenizer is only imported when `is_tokenizers_available()` returns True.
if is_tokenizers_available():
    from test_module.custom_tokenization_fast import CustomTokenizerFast


class TokenizerUtilTester(unittest.TestCase):
    """Tests for tokenizer loading: cache fallback when the server is down and legacy single-file loading."""

    @staticmethod
    def _server_down_response():
        """Return a mock HTTP response emulating a server that answers every request with a 500 error."""
        response_mock = mock.Mock()
        response_mock.status_code = 500
        response_mock.headers = {}
        response_mock.raise_for_status.side_effect = HTTPError
        response_mock.json.return_value = {}
        return response_mock

    def test_cached_files_are_used_when_internet_is_down(self):
        """Loading a tokenizer a second time must succeed even if every HTTP request returns a 500."""
        # A mock response for an HTTP head request to emulate server down
        response_mock = self._server_down_response()

        # Download this model to make sure it's in the cache.
        _ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")

        # Under the mock environment we get a 500 error when trying to reach the tokenizer.
        with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
            _ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
            # This checks that we did call the fake head request
            mock_head.assert_called()

    @require_tokenizers
    def test_cached_files_are_used_when_internet_is_down_missing_files(self):
        """Same as above, for the fast GPT-2 tokenizer checkpoint."""
        # A mock response for an HTTP head request to emulate server down
        response_mock = self._server_down_response()

        # Download this model to make sure it's in the cache.
        _ = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")

        # Under the mock environment we get a 500 error when trying to reach the tokenizer.
        with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
            _ = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
            # This checks that we did call the fake head request
            mock_head.assert_called()

    def test_legacy_load_from_one_file(self):
        """Load a tokenizer from a single vocab file, and check a local `tokenizer.json` is not picked up."""
        # This test is for deprecated behavior and can be removed in v5
        # `mkstemp` creates the file securely and up front, unlike the deprecated, race-prone `mktemp`,
        # so `tmp_file` is always bound (and exists) by the time the `finally` clause runs.
        fd, tmp_file = tempfile.mkstemp()
        try:
            with os.fdopen(fd, "wb") as f:
                http_get("https://huggingface.co/albert/albert-base-v1/resolve/main/spiece.model", f)

            _ = AlbertTokenizer.from_pretrained(tmp_file)
        finally:
            os.remove(tmp_file)

        # Supporting this legacy load introduced a weird bug where the tokenizer would load local files if they are in
        # the current folder and have the right name.
        if os.path.isfile("tokenizer.json"):
            # We skip the test if the user has a `tokenizer.json` in this folder to avoid deleting it.
            self.skipTest(reason="Skipping test as there is a `tokenizer.json` file in the current folder.")
        try:
            with open("tokenizer.json", "wb") as f:
                http_get("https://huggingface.co/hf-internal-testing/tiny-random-bert/blob/main/tokenizer.json", f)
            tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
            # The tiny random BERT has a vocab size of 1024, tiny openai-community/gpt2 has a vocab size of 1000
            self.assertEqual(tokenizer.vocab_size, 1000)
            # Tokenizer should depend on the remote checkpoint, not the local tokenizer.json file.

        finally:
            os.remove("tokenizer.json")


@is_staging_test
class TokenizerPushToHubTester(unittest.TestCase):
    """Staging tests that push tokenizers to the Hub and check they can be loaded back."""

    # Vocabulary written, one token per line, to the `vocab.txt` used by the tests below.
    vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"]

    @classmethod
    def setUpClass(cls):
        """Store the staging token on the class and save it through `HfFolder`."""
        cls._token = TOKEN
        HfFolder.save_token(TOKEN)

    @classmethod
    def tearDownClass(cls):
        """Best-effort removal of every repo the tests may have created."""
        for repo_id in ("test-tokenizer", "valid_org/test-tokenizer-org", "test-dynamic-tokenizer"):
            cls._delete_repo_if_possible(repo_id)

    @classmethod
    def _delete_repo_if_possible(cls, repo_id):
        """Delete `repo_id`, ignoring HTTP errors from the Hub."""
        try:
            delete_repo(token=cls._token, repo_id=repo_id)
        except HTTPError:
            # Only HTTP errors are ignored; a bare `except` would also hide KeyboardInterrupt/SystemExit.
            pass

    def _write_vocab_file(self, directory):
        """Write `vocab_tokens` (one per line) to `vocab.txt` inside `directory` and return its path."""
        vocab_file = os.path.join(directory, "vocab.txt")
        with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join(f"{token}\n" for token in self.vocab_tokens))
        return vocab_file

    def test_push_to_hub(self):
        """Push to the user namespace via `push_to_hub`, then via `save_pretrained(push_to_hub=True)`."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tokenizer = BertTokenizer(self._write_vocab_file(tmp_dir))

        tokenizer.push_to_hub("test-tokenizer", token=self._token)
        new_tokenizer = BertTokenizer.from_pretrained(f"{USER}/test-tokenizer")
        self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)

        # Reset repo
        self._delete_repo_if_possible("test-tokenizer")

        # Push to hub via save_pretrained
        with tempfile.TemporaryDirectory() as tmp_dir:
            tokenizer.save_pretrained(tmp_dir, repo_id="test-tokenizer", push_to_hub=True, token=self._token)

        new_tokenizer = BertTokenizer.from_pretrained(f"{USER}/test-tokenizer")
        self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)

    def test_push_to_hub_in_organization(self):
        """Same round trip as `test_push_to_hub`, targeting the `valid_org` organization repo."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tokenizer = BertTokenizer(self._write_vocab_file(tmp_dir))

        tokenizer.push_to_hub("valid_org/test-tokenizer-org", token=self._token)
        new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
        self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)

        # Reset repo
        self._delete_repo_if_possible("valid_org/test-tokenizer-org")

        # Push to hub via save_pretrained
        with tempfile.TemporaryDirectory() as tmp_dir:
            tokenizer.save_pretrained(
                tmp_dir, repo_id="valid_org/test-tokenizer-org", push_to_hub=True, token=self._token
            )

        new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
        self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)

    @require_tokenizers
    def test_push_to_hub_dynamic_tokenizer(self):
        """Push custom (slow, then fast) tokenizers and reload them through `AutoTokenizer`."""
        CustomTokenizer.register_for_auto_class()
        with tempfile.TemporaryDirectory() as tmp_dir:
            tokenizer = CustomTokenizer(self._write_vocab_file(tmp_dir))

        # No fast custom tokenizer
        tokenizer.push_to_hub("test-dynamic-tokenizer", token=self._token)

        tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True)
        # Can't make an isinstance check because the tokenizer class comes from a dynamic module
        self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer")

        # Fast and slow custom tokenizer
        CustomTokenizerFast.register_for_auto_class()
        with tempfile.TemporaryDirectory() as tmp_dir:
            self._write_vocab_file(tmp_dir)

            bert_tokenizer = BertTokenizerFast.from_pretrained(tmp_dir)
            bert_tokenizer.save_pretrained(tmp_dir)
            tokenizer = CustomTokenizerFast.from_pretrained(tmp_dir)

        tokenizer.push_to_hub("test-dynamic-tokenizer", token=self._token)

        tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True)
        # Can't make an isinstance check because the tokenizer class comes from a dynamic module
        self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizerFast")
        tokenizer = AutoTokenizer.from_pretrained(
            f"{USER}/test-dynamic-tokenizer", use_fast=False, trust_remote_code=True
        )
        # Can't make an isinstance check because the tokenizer class comes from a dynamic module
        self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer")


class TrieTest(unittest.TestCase):
    """Tests for `Trie`: its nested-dict storage and how `split` / `cut_text` cut a string."""

    def test_trie(self):
        """Added words are stored character by character, with an `""` key marking the end of a word."""
        trie = Trie()
        trie.add("Hello 友達")
        self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}})
        trie.add("Hello")
        self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}})

    def test_trie_split(self):
        """An empty trie leaves the text whole; with tokens added, the longer `extra_id_100` wins over `extra_id_1`."""
        trie = Trie()
        self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS] This is a extra_id_100"])
        trie.add("[CLS]")
        trie.add("extra_id_1")
        trie.add("extra_id_100")
        self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS]", " This is a ", "extra_id_100"])

    def test_trie_single(self):
        """A single-character token is split off both at the start and at the end of the text."""
        trie = Trie()
        trie.add("A")
        self.assertEqual(trie.split("ABC"), ["A", "BC"])
        self.assertEqual(trie.split("BCA"), ["BC", "A"])

    def test_trie_final(self):
        """A token that is a suffix of another token must not break the longer match at the end of the text."""
        trie = Trie()
        trie.add("TOKEN]")
        trie.add("[SPECIAL_TOKEN]")
        self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])

    def test_trie_subtokens(self):
        """Single-character tokens occurring inside a longer token must not split that longer token."""
        trie = Trie()
        trie.add("A")
        trie.add("P")
        trie.add("[SPECIAL_TOKEN]")
        self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])

    def test_trie_suffix_tokens(self):
        """The match that starts first (`AB`) is kept over a token starting later inside it (`B`)."""
        trie = Trie()
        trie.add("AB")
        trie.add("B")
        trie.add("C")
        self.assertEqual(trie.split("ABC"), ["AB", "C"])

    def test_trie_skip(self):
        """Once `ABC` is matched, the overlapping `B` and `CD` candidates are dropped."""
        trie = Trie()
        trie.add("ABC")
        trie.add("B")
        trie.add("CD")
        self.assertEqual(trie.split("ABCD"), ["ABC", "D"])

    def test_cut_text_hardening(self):
        """`cut_text` still returns the expected parts when given repeated and out-of-order offsets."""
        # Even if the offsets are wrong, we necessarily output correct string
        # parts.
        trie = Trie()
        parts = trie.cut_text("ABC", [0, 0, 2, 1, 2, 3])
        self.assertEqual(parts, ["AB", "C"])


class ExtensionsTrieTest(unittest.TestCase):
    """Tests for `ExtensionsTrie.extensions`, which is queried with a prefix."""

    def test_extensions(self):
        """A prefix query returns the prefix itself (when stored) plus every stored word extending it."""
        # Test searching by prefix
        prefix_trie = ExtensionsTrie()
        for word in ("foo", "food", "foodie", "helium"):
            prefix_trie.add(word)
        foo_matches = prefix_trie.extensions("foo")
        helium_matches = prefix_trie.extensions("helium")
        self.assertEqual(foo_matches, ["foo", "food", "foodie"])
        self.assertEqual(helium_matches, ["helium"])

    def test_empty_prefix(self):
        """Querying with the empty string is expected to return every stored word."""
        prefix_trie = ExtensionsTrie()
        # Test searching with an empty prefix returns all values
        for word in ("hello", "bye"):
            prefix_trie.add(word)
        everything = prefix_trie.extensions("")
        self.assertEqual(everything, ["hello", "bye"])

    def test_no_extension_match(self):
        """Querying a prefix absent from the trie is expected to raise `KeyError`."""
        empty_trie = ExtensionsTrie()
        # Test searching for a prefix that doesn't match any key
        self.assertRaises(KeyError, empty_trie.extensions, "unknown")

    def test_update_value(self):
        """Adding the same word twice must not produce a duplicate entry."""
        prefix_trie = ExtensionsTrie()
        # Test updating the value of an existing key
        for _ in range(2):
            prefix_trie.add("hi")
        self.assertEqual(prefix_trie.extensions("hi"), ["hi"])