# coding=utf-8
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import pickle
import unittest

from transformers import AutoTokenizer
from transformers.models.bert_japanese.tokenization_bert_japanese import (
    VOCAB_FILES_NAMES,
    BertJapaneseTokenizer,
    BertTokenizer,
    CharacterTokenizer,
    MecabTokenizer,
    WordpieceTokenizer,
)
from transformers.testing_utils import custom_tokenizers

from .test_tokenization_common import TokenizerTesterMixin


@custom_tokenizers
class BertJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

    tokenizer_class = BertJapaneseTokenizer
    test_rust_tokenizer = False
    space_between_special_tokens = True

    def setUp(self):
        super().setUp()

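        # Toy vocabulary: whole-word tokens plus "##"-prefixed continuation pieces used by the tests below.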
        vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "こんにちは",
            "こん",
            "にちは",
            "ばんは",
            "##こん",
            "##にちは",
            "##ばんは",
            "世界",
            "##世界",
            "、",
            "##、",
            "。",
            "##。",
        ]

        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

    def get_input_output_texts(self, tokenizer):
        input_text = "こんにちは、世界。 \nこんばんは、世界。"
        output_text = "こんにちは 、 世界 。 こんばんは 、 世界 。"
        return input_text, output_text

    def get_clean_sequence(self, tokenizer):
        input_text, output_text = self.get_input_output_texts(tokenizer)
        ids = tokenizer.encode(output_text, add_special_tokens=False)
        text = tokenizer.decode(ids, clean_up_tokenization_spaces=False)
        return text, ids

    def test_pretokenized_inputs(self):
        pass  # TODO add if relevant

    def test_maximum_encoding_length_pair_input(self):
        pass  # TODO add if relevant

    def test_maximum_encoding_length_single_input(self):
        pass  # TODO add if relevant

    def test_full_tokenizer(self):
        tokenizer = self.tokenizer_class(self.vocab_file)

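        # "こんばんは" is not in the toy vocab, so it is split into "こん" + "##ばんは".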
        tokens = tokenizer.tokenize("こんにちは、世界。\nこんばんは、世界。")
        self.assertListEqual(tokens, ["こんにちは", "、", "世界", "。", "こん", "##ばんは", "、", "世界", "。"])
        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [3, 12, 10, 14, 4, 9, 12, 10, 14])

    def test_pickle_mecab_tokenizer(self):
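        # The MeCab word tokenizer wraps an external tagger, so check that the tokenizer
        # survives a pickle round trip with identical tokenization.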
        tokenizer = self.tokenizer_class(self.vocab_file, word_tokenizer_type="mecab")
        self.assertIsNotNone(tokenizer)

        text = "こんにちは、世界。\nこんばんは、世界。"
        tokens = tokenizer.tokenize(text)
        self.assertListEqual(tokens, ["こんにちは", "、", "世界", "。", "こん", "##ばんは", "、", "世界", "。"])
        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [3, 12, 10, 14, 4, 9, 12, 10, 14])

        filename = os.path.join(self.tmpdirname, "tokenizer.bin")
        with open(filename, "wb") as handle:
            pickle.dump(tokenizer, handle)

        with open(filename, "rb") as handle:
            tokenizer_new = pickle.load(handle)

        tokens_loaded = tokenizer_new.tokenize(text)

        self.assertListEqual(tokens, tokens_loaded)

    def test_mecab_tokenizer_ipadic(self):
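        # With the IPAdic dictionary, "アップルストア" stays a single token and the surrounding whitespace is stripped.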
        tokenizer = MecabTokenizer(mecab_dic="ipadic")

        self.assertListEqual(
            tokenizer.tokenize(" \tアップルストアでiPhone8 が  \n 発売された 。  "),
            ["アップルストア", "で", "iPhone", "8", "が", "発売", "さ", "れ", "た", "。"],
        )

    def test_mecab_tokenizer_unidic_lite(self):
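        # unidic_lite is an optional dependency; bail out quietly if it is not installed.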
        try:
            tokenizer = MecabTokenizer(mecab_dic="unidic_lite")
        except ModuleNotFoundError:
            return

        self.assertListEqual(
            tokenizer.tokenize(" \tアップルストアでiPhone8 が  \n 発売された 。  "),
            ["アップル", "ストア", "で", "iPhone", "8", "が", "発売", "さ", "れ", "た", "。"],
        )

    def test_mecab_tokenizer_unidic(self):
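        # Unlike IPAdic, UniDic segments "アップルストア" into "アップル" + "ストア".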
        try:
            tokenizer = MecabTokenizer(mecab_dic="unidic")
        except ModuleNotFoundError:
            return

        self.assertListEqual(
            tokenizer.tokenize(" \tアップルストアでiPhone8 が  \n 発売された 。  "),
            ["アップル", "ストア", "で", "iPhone", "8", "が", "発売", "さ", "れ", "た", "。"],
        )

    def test_mecab_tokenizer_lower(self):
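        # do_lower_case lowercases "iPhone" to "iphone"; the Japanese text is unaffected.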
        tokenizer = MecabTokenizer(do_lower_case=True, mecab_dic="ipadic")

        self.assertListEqual(
            tokenizer.tokenize(" \tアップルストアでiPhone8 が  \n 発売された 。  "),
            ["アップルストア", "で", "iphone", "8", "が", "発売", "さ", "れ", "た", "。"],
        )

    def test_mecab_tokenizer_with_option(self):
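        # Pass a raw MeCab option string; the jumandic install path is machine-dependent, so tolerate its absence.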
        try:
            tokenizer = MecabTokenizer(
                do_lower_case=True, normalize_text=False, mecab_option="-d /usr/local/lib/mecab/dic/jumandic"
            )
        except RuntimeError:
            # MeCab raises a RuntimeError when the requested dictionary is not installed at this path.
            return

        self.assertListEqual(
            tokenizer.tokenize(" \tアップルストアでiPhone8 が  \n 発売された 。  "),
            ["アップルストア", "で", "iPhone", "8", "が", "発売", "さ", "れた", "\u3000", "。"],
        )

    def test_mecab_tokenizer_no_normalize(self):
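        # Without NFKC normalization, the full-width space in the input is preserved and comes back as its own token.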
        tokenizer = MecabTokenizer(normalize_text=False, mecab_dic="ipadic")

        self.assertListEqual(
            tokenizer.tokenize(" \tアップルストアでiPhone8 が  \n 発売された 。  "),
            ["アップルストア", "で", "iPhone", "8", "が", "発売", "さ", "れ", "た", " ", "。"],
        )

    def test_wordpiece_tokenizer(self):
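        # Exercise WordpieceTokenizer directly (no word-level pre-tokenization): greedy longest-match
        # over the vocab, with whole words that cannot be covered mapped to [UNK].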
        vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "こんにちは", "こん", "にちは", "ばんは", "##こん", "##にちは", "##ばんは"]

        vocab = {}
        for (i, token) in enumerate(vocab_tokens):
            vocab[token] = i
        tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")

        self.assertListEqual(tokenizer.tokenize(""), [])

        self.assertListEqual(tokenizer.tokenize("こんにちは"), ["こんにちは"])

        self.assertListEqual(tokenizer.tokenize("こんばんは"), ["こん", "##ばんは"])

        self.assertListEqual(tokenizer.tokenize("こんばんは こんばんにちは こんにちは"), ["こん", "##ばんは", "[UNK]", "こんにちは"])

    def test_sequence_builders(self):
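        # build_inputs_with_special_tokens should wrap single sentences and sentence pairs with [CLS]/[SEP].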
        tokenizer = self.tokenizer_class.from_pretrained("cl-tohoku/bert-base-japanese")

        text = tokenizer.encode("ありがとう。", add_special_tokens=False)
        text_2 = tokenizer.encode("どういたしまして。", add_special_tokens=False)

        encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
        encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)

        # 2 is for "[CLS]", 3 is for "[SEP]"
        assert encoded_sentence == [2] + text + [3]
        assert encoded_pair == [2] + text + [3] + text_2 + [3]


@custom_tokenizers
class BertJapaneseCharacterTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

    tokenizer_class = BertJapaneseTokenizer
    test_rust_tokenizer = False

    def setUp(self):
        super().setUp()

        vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "こ", "ん", "に", "ち", "は", "ば", "世", "界", "、", "。"]

        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

    def get_tokenizer(self, **kwargs):
        return BertJapaneseTokenizer.from_pretrained(self.tmpdirname, subword_tokenizer_type="character", **kwargs)

    def get_input_output_texts(self, tokenizer):
        input_text = "こんにちは、世界。 \nこんばんは、世界。"
        output_text = "こ ん に ち は 、 世 界 。 こ ん ば ん は 、 世 界 。"
        return input_text, output_text

    def test_pretokenized_inputs(self):
        pass  # TODO add if relevant

    def test_maximum_encoding_length_pair_input(self):
        pass  # TODO add if relevant

    def test_maximum_encoding_length_single_input(self):
        pass  # TODO add if relevant

    def test_full_tokenizer(self):
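        # With subword_tokenizer_type="character", every character becomes its own token.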
        tokenizer = self.tokenizer_class(self.vocab_file, subword_tokenizer_type="character")

        tokens = tokenizer.tokenize("こんにちは、世界。 \nこんばんは、世界。")
        self.assertListEqual(
            tokens, ["こ", "ん", "に", "ち", "は", "、", "世", "界", "。", "こ", "ん", "ば", "ん", "は", "、", "世", "界", "。"]
        )
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens), [3, 4, 5, 6, 7, 11, 9, 10, 12, 3, 4, 8, 4, 7, 11, 9, 10, 12]
        )

    def test_character_tokenizer(self):
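        # CharacterTokenizer maps characters missing from the vocab (here "ほ") to [UNK].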
        vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "こ", "ん", "に", "ち", "は", "ば", "世", "界", "、", "。"]

        vocab = {}
        for (i, token) in enumerate(vocab_tokens):
            vocab[token] = i
        tokenizer = CharacterTokenizer(vocab=vocab, unk_token="[UNK]")

        self.assertListEqual(tokenizer.tokenize(""), [])

        self.assertListEqual(tokenizer.tokenize("こんにちは"), ["こ", "ん", "に", "ち", "は"])

        self.assertListEqual(tokenizer.tokenize("こんにちほ"), ["こ", "ん", "に", "ち", "[UNK]"])

    def test_sequence_builders(self):
        tokenizer = self.tokenizer_class.from_pretrained("cl-tohoku/bert-base-japanese-char")

        text = tokenizer.encode("ありがとう。", add_special_tokens=False)
        text_2 = tokenizer.encode("どういたしまして。", add_special_tokens=False)

        encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
        encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)

        # 2 is for "[CLS]", 3 is for "[SEP]"
        assert encoded_sentence == [2] + text + [3]
        assert encoded_pair == [2] + text + [3] + text_2 + [3]


@custom_tokenizers
class AutoTokenizerCustomTest(unittest.TestCase):
    def test_tokenizer_bert_japanese(self):
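        # AutoTokenizer should resolve this checkpoint to BertJapaneseTokenizer, not to the generic BertTokenizer.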
        EXAMPLE_BERT_JAPANESE_ID = "cl-tohoku/bert-base-japanese"
        tokenizer = AutoTokenizer.from_pretrained(EXAMPLE_BERT_JAPANESE_ID)
        self.assertIsInstance(tokenizer, BertJapaneseTokenizer)


class BertTokenizerMismatchTest(unittest.TestCase):
    def test_tokenizer_mismatch_warning(self):
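        # Loading a checkpoint with a mismatched tokenizer class should log a warning in both directions
        # (BertTokenizer on a Japanese checkpoint, BertJapaneseTokenizer on an English one) but not fail.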
        EXAMPLE_BERT_JAPANESE_ID = "cl-tohoku/bert-base-japanese"
        with self.assertLogs("transformers", level="WARNING") as cm:
            BertTokenizer.from_pretrained(EXAMPLE_BERT_JAPANESE_ID)
            self.assertTrue(
                cm.records[0].message.startswith(
                    "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
                )
            )
        EXAMPLE_BERT_ID = "bert-base-cased"
        with self.assertLogs("transformers", level="WARNING") as cm:
            BertJapaneseTokenizer.from_pretrained(EXAMPLE_BERT_ID)
            self.assertTrue(
                cm.records[0].message.startswith(
                    "The tokenizer class you load from this checkpoint is not the same type as the class this function is called from."
                )
            )