"conda/git@developer.sourcefind.cn:OpenDAS/torch-scatter.git" did not exist on "f95330e2417368a566f5a5a8965ee96a45a4d942"
test_tokenization_utils.py 14.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import tempfile
import unittest
import unittest.mock as mock
from pathlib import Path

from huggingface_hub import HfFolder, delete_repo
from huggingface_hub.file_download import http_get
from requests.exceptions import HTTPError

from transformers import (
    AlbertTokenizer,
    AutoTokenizer,
    BertTokenizer,
    BertTokenizerFast,
    GPT2TokenizerFast,
    is_tokenizers_available,
)
from transformers.testing_utils import TOKEN, USER, is_staging_test, require_tokenizers
Ahmed Moubtahij's avatar
Ahmed Moubtahij committed
36
from transformers.tokenization_utils import ExtensionsTrie, Trie
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75


sys.path.append(str(Path(__file__).parent.parent / "utils"))

from test_module.custom_tokenization import CustomTokenizer  # noqa E402


if is_tokenizers_available():
    from test_module.custom_tokenization_fast import CustomTokenizerFast


class TokenizerUtilTester(unittest.TestCase):
    def test_cached_files_are_used_when_internet_is_down(self):
        # A mock response for an HTTP head request to emulate server down
        response_mock = mock.Mock()
        response_mock.status_code = 500
        response_mock.headers = {}
        response_mock.raise_for_status.side_effect = HTTPError
        response_mock.json.return_value = {}

        # Download this model to make sure it's in the cache.
        _ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")

        # Under the mock environment we get a 500 error when trying to reach the tokenizer.
        with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
            _ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
            # This check we did call the fake head request
            mock_head.assert_called()

    @require_tokenizers
    def test_cached_files_are_used_when_internet_is_down_missing_files(self):
        # A mock response for an HTTP head request to emulate server down
        response_mock = mock.Mock()
        response_mock.status_code = 500
        response_mock.headers = {}
        response_mock.raise_for_status.side_effect = HTTPError
        response_mock.json.return_value = {}

        # Download this model to make sure it's in the cache.
76
        _ = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
77
78
79

        # Under the mock environment we get a 500 error when trying to reach the tokenizer.
        with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
80
            _ = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
81
82
83
84
85
86
            # This check we did call the fake head request
            mock_head.assert_called()

    def test_legacy_load_from_one_file(self):
        # This test is for deprecated behavior and can be removed in v5
        try:
87
            tmp_file = tempfile.NamedTemporaryFile(delete=False).name
88
            with open(tmp_file, "wb") as f:
89
                http_get("https://huggingface.co/albert/albert-base-v1/resolve/main/spiece.model", f)
90
91
92
93
94
95
96
97
98

            _ = AlbertTokenizer.from_pretrained(tmp_file)
        finally:
            os.remove(tmp_file)

        # Supporting this legacy load introduced a weird bug where the tokenizer would load local files if they are in
        # the current folder and have the right name.
        if os.path.isfile("tokenizer.json"):
            # We skip the test if the user has a `tokenizer.json` in this folder to avoid deleting it.
amyeroberts's avatar
amyeroberts committed
99
            self.skipTest(reason="Skipping test as there is a `tokenizer.json` file in the current folder.")
100
101
102
103
        try:
            with open("tokenizer.json", "wb") as f:
                http_get("https://huggingface.co/hf-internal-testing/tiny-random-bert/blob/main/tokenizer.json", f)
            tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
104
            # The tiny random BERT has a vocab size of 1024, tiny openai-community/gpt2 as a vocab size of 1000
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
            self.assertEqual(tokenizer.vocab_size, 1000)
            # Tokenizer should depend on the remote checkpoint, not the local tokenizer.json file.

        finally:
            os.remove("tokenizer.json")


@is_staging_test
class TokenizerPushToHubTester(unittest.TestCase):
    vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"]

    @classmethod
    def setUpClass(cls):
        cls._token = TOKEN
        HfFolder.save_token(TOKEN)

Yih-Dar's avatar
Yih-Dar committed
121
122
    @staticmethod
    def _try_delete_repo(repo_id, token):
Yih-Dar's avatar
Yih-Dar committed
123
124
        try:
            # Reset repo
Yih-Dar's avatar
Yih-Dar committed
125
            delete_repo(repo_id=repo_id, token=token)
Yih-Dar's avatar
Yih-Dar committed
126
127
        except:  # noqa E722
            pass
128

Yih-Dar's avatar
Yih-Dar committed
129
    def test_push_to_hub(self):
130
        with tempfile.TemporaryDirectory() as tmp_dir:
Yih-Dar's avatar
Yih-Dar committed
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
            try:
                tmp_repo = f"{USER}/test-tokenizer-{Path(tmp_dir).name}"
                vocab_file = os.path.join(tmp_dir, "vocab.txt")
                with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
                    vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
                tokenizer = BertTokenizer(vocab_file)

                tokenizer.push_to_hub(tmp_repo, token=self._token)
                new_tokenizer = BertTokenizer.from_pretrained(tmp_repo)
                self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
            finally:
                # Always (try to) delete the repo.
                self._try_delete_repo(repo_id=tmp_repo, token=self._token)

    def test_push_to_hub_via_save_pretrained(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            try:
                tmp_repo = f"{USER}/test-tokenizer-{Path(tmp_dir).name}"
                vocab_file = os.path.join(tmp_dir, "vocab.txt")
                with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
                    vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
                tokenizer = BertTokenizer(vocab_file)

                # Push to hub via save_pretrained
                tokenizer.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)

                new_tokenizer = BertTokenizer.from_pretrained(tmp_repo)
                self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
            finally:
                # Always (try to) delete the repo.
                self._try_delete_repo(repo_id=tmp_repo, token=self._token)
162
163
164

    def test_push_to_hub_in_organization(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
Yih-Dar's avatar
Yih-Dar committed
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
            try:
                tmp_repo = f"valid_org/test-tokenizer-{Path(tmp_dir).name}"
                vocab_file = os.path.join(tmp_dir, "vocab.txt")
                with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
                    vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
                tokenizer = BertTokenizer(vocab_file)

                tokenizer.push_to_hub(tmp_repo, token=self._token)
                new_tokenizer = BertTokenizer.from_pretrained(tmp_repo)
                self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
            finally:
                # Always (try to) delete the repo.
                self._try_delete_repo(repo_id=tmp_repo, token=self._token)

    def test_push_to_hub_in_organization_via_save_pretrained(self):
180
        with tempfile.TemporaryDirectory() as tmp_dir:
Yih-Dar's avatar
Yih-Dar committed
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
            try:
                tmp_repo = f"valid_org/test-tokenizer-{Path(tmp_dir).name}"
                vocab_file = os.path.join(tmp_dir, "vocab.txt")
                with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
                    vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
                tokenizer = BertTokenizer(vocab_file)

                # Push to hub via save_pretrained
                tokenizer.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)

                new_tokenizer = BertTokenizer.from_pretrained(tmp_repo)
                self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
            finally:
                # Always (try to) delete the repo.
                self._try_delete_repo(repo_id=tmp_repo, token=self._token)
196
197
198
199

    @require_tokenizers
    def test_push_to_hub_dynamic_tokenizer(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
Yih-Dar's avatar
Yih-Dar committed
200
201
202
203
204
205
206
207
            try:
                tmp_repo = f"{USER}/test-dynamic-tokenizer-{Path(tmp_dir).name}"
                CustomTokenizer.register_for_auto_class()

                vocab_file = os.path.join(tmp_dir, "vocab.txt")
                with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
                    vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
                tokenizer = CustomTokenizer(vocab_file)
208

Yih-Dar's avatar
Yih-Dar committed
209
210
                # No fast custom tokenizer
                tokenizer.push_to_hub(tmp_repo, token=self._token)
211

Yih-Dar's avatar
Yih-Dar committed
212
213
214
215
216
217
                tokenizer = AutoTokenizer.from_pretrained(tmp_repo, trust_remote_code=True)
                # Can't make an isinstance check because the new_model.config is from the CustomTokenizer class of a dynamic module
                self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer")
            finally:
                # Always (try to) delete the repo.
                self._try_delete_repo(repo_id=tmp_repo, token=self._token)
218

Yih-Dar's avatar
Yih-Dar committed
219
220
    @require_tokenizers
    def test_push_to_hub_dynamic_tokenizer_with_both_slow_and_fast_classes(self):
221
        with tempfile.TemporaryDirectory() as tmp_dir:
Yih-Dar's avatar
Yih-Dar committed
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
            try:
                tmp_repo = f"{USER}/test-dynamic-tokenizer-{Path(tmp_dir).name}"
                CustomTokenizer.register_for_auto_class()

                # Fast and slow custom tokenizer
                CustomTokenizerFast.register_for_auto_class()

                vocab_file = os.path.join(tmp_dir, "vocab.txt")
                with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
                    vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))

                bert_tokenizer = BertTokenizerFast.from_pretrained(tmp_dir)
                bert_tokenizer.save_pretrained(tmp_dir)
                tokenizer = CustomTokenizerFast.from_pretrained(tmp_dir)

                tokenizer.push_to_hub(tmp_repo, token=self._token)

                tokenizer = AutoTokenizer.from_pretrained(tmp_repo, trust_remote_code=True)
                # Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
                self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizerFast")
                tokenizer = AutoTokenizer.from_pretrained(tmp_repo, use_fast=False, trust_remote_code=True)
                # Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
                self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer")
            finally:
                # Always (try to) delete the repo.
                self._try_delete_repo(repo_id=tmp_repo, token=self._token)
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305


class TrieTest(unittest.TestCase):
    def test_trie(self):
        trie = Trie()
        trie.add("Hello 友達")
        self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}})
        trie.add("Hello")
        trie.data
        self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}})

    def test_trie_split(self):
        trie = Trie()
        self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS] This is a extra_id_100"])
        trie.add("[CLS]")
        trie.add("extra_id_1")
        trie.add("extra_id_100")
        self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS]", " This is a ", "extra_id_100"])

    def test_trie_single(self):
        trie = Trie()
        trie.add("A")
        self.assertEqual(trie.split("ABC"), ["A", "BC"])
        self.assertEqual(trie.split("BCA"), ["BC", "A"])

    def test_trie_final(self):
        trie = Trie()
        trie.add("TOKEN]")
        trie.add("[SPECIAL_TOKEN]")
        self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])

    def test_trie_subtokens(self):
        trie = Trie()
        trie.add("A")
        trie.add("P")
        trie.add("[SPECIAL_TOKEN]")
        self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])

    def test_trie_suffix_tokens(self):
        trie = Trie()
        trie.add("AB")
        trie.add("B")
        trie.add("C")
        self.assertEqual(trie.split("ABC"), ["AB", "C"])

    def test_trie_skip(self):
        trie = Trie()
        trie.add("ABC")
        trie.add("B")
        trie.add("CD")
        self.assertEqual(trie.split("ABCD"), ["ABC", "D"])

    def test_cut_text_hardening(self):
        # Even if the offsets are wrong, we necessarily output correct string
        # parts.
        trie = Trie()
        parts = trie.cut_text("ABC", [0, 0, 2, 1, 2, 3])
        self.assertEqual(parts, ["AB", "C"])
Ahmed Moubtahij's avatar
Ahmed Moubtahij committed
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337


class ExtensionsTrieTest(unittest.TestCase):
    def test_extensions(self):
        # Test searching by prefix
        trie = ExtensionsTrie()
        trie.add("foo")
        trie.add("food")
        trie.add("foodie")
        trie.add("helium")
        self.assertEqual(trie.extensions("foo"), ["foo", "food", "foodie"])
        self.assertEqual(trie.extensions("helium"), ["helium"])

    def test_empty_prefix(self):
        trie = ExtensionsTrie()
        # Test searching with an empty prefix returns all values
        trie.add("hello")
        trie.add("bye")
        self.assertEqual(trie.extensions(""), ["hello", "bye"])

    def test_no_extension_match(self):
        trie = ExtensionsTrie()
        # Test searching for a prefix that doesn't match any key
        with self.assertRaises(KeyError):
            trie.extensions("unknown")

    def test_update_value(self):
        trie = ExtensionsTrie()
        # Test updating the value of an existing key
        trie.add("hi")
        trie.add("hi")
        self.assertEqual(trie.extensions("hi"), ["hi"])