# coding=utf-8
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import pickle
import unittest

from transformers import AutoTokenizer
from transformers.models.bert_japanese.tokenization_bert_japanese import (
    VOCAB_FILES_NAMES,
    BertJapaneseTokenizer,
    CharacterTokenizer,
    MecabTokenizer,
    WordpieceTokenizer,
)
from transformers.testing_utils import custom_tokenizers

from .test_tokenization_common import TokenizerTesterMixin


@custom_tokenizers
class BertJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

    tokenizer_class = BertJapaneseTokenizer
    space_between_special_tokens = True

    def setUp(self):
        super().setUp()

        vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "こんにちは",
            "こん",
            "にちは",
            "ばんは",
            "##こん",
            "##にちは",
            "##ばんは",
            "世界",
            "##世界",
            "、",
            "##、",
            "。",
            "##。",
        ]

        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

    def get_input_output_texts(self, tokenizer):
        input_text = "こんにちは、世界。 \nこんばんは、世界。"
        output_text = "こんにちは 、 世界 。 こんばんは 、 世界 。"
        return input_text, output_text

    def get_clean_sequence(self, tokenizer):
        input_text, output_text = self.get_input_output_texts(tokenizer)
        ids = tokenizer.encode(output_text, add_special_tokens=False)
        text = tokenizer.decode(ids, clean_up_tokenization_spaces=False)
        return text, ids

    def test_pretokenized_inputs(self):
        pass  # TODO add if relevant

    def test_maximum_encoding_length_pair_input(self):
        pass  # TODO add if relevant

    def test_maximum_encoding_length_single_input(self):
        pass  # TODO add if relevant

    def test_full_tokenizer(self):
        tokenizer = self.tokenizer_class(self.vocab_file)

        tokens = tokenizer.tokenize("こんにちは、世界。\nこんばんは、世界。")
        self.assertListEqual(tokens, ["こんにちは", "、", "世界", "。", "こん", "##ばんは", "、", "世界", "。"])
        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [3, 12, 10, 14, 4, 9, 12, 10, 14])

    def test_pickle_mecab_tokenizer(self):
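        # Pickling and unpickling should round-trip the MeCab-based tokenizer and its vocabulary.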
        tokenizer = self.tokenizer_class(self.vocab_file, word_tokenizer_type="mecab")
        self.assertIsNotNone(tokenizer)

        text = "こんにちは、世界。\nこんばんは、世界。"
        tokens = tokenizer.tokenize(text)
        self.assertListEqual(tokens, ["こんにちは", "、", "世界", "。", "こん", "##ばんは", "、", "世界", "。"])
        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [3, 12, 10, 14, 4, 9, 12, 10, 14])

        filename = os.path.join(self.tmpdirname, "tokenizer.bin")
        with open(filename, "wb") as handle:
            pickle.dump(tokenizer, handle)

        with open(filename, "rb") as handle:
            tokenizer_new = pickle.load(handle)

        tokens_loaded = tokenizer_new.tokenize(text)

        self.assertListEqual(tokens, tokens_loaded)

    def test_mecab_tokenizer_ipadic(self):
        tokenizer = MecabTokenizer(mecab_dic="ipadic")

        self.assertListEqual(
            tokenizer.tokenize(" \tアップルストアでiPhone8 が  \n 発売された 。  "),
            ["アップルストア", "で", "iPhone", "8", "が", "発売", "さ", "れ", "た", "。"],
        )

    def test_mecab_tokenizer_unidic_lite(self):
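        # unidic_lite is an optional dependency; skip silently if it is not installed.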
        try:
            tokenizer = MecabTokenizer(mecab_dic="unidic_lite")
        except ModuleNotFoundError:
            return

        self.assertListEqual(
            tokenizer.tokenize(" \tアップルストアでiPhone8 が  \n 発売された 。  "),
            ["アップル", "ストア", "で", "iPhone", "8", "が", "発売", "さ", "れ", "た", "。"],
        )

    def test_mecab_tokenizer_unidic(self):
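        # unidic (the full UniDic dictionary) is an optional dependency; skip silently if it is not installed.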
        try:
            tokenizer = MecabTokenizer(mecab_dic="unidic")
        except ModuleNotFoundError:
            return

        self.assertListEqual(
            tokenizer.tokenize(" \tアップルストアでiPhone8 が  \n 発売された 。  "),
            ["アップル", "ストア", "で", "iPhone", "8", "が", "発売", "さ", "れ", "た", "。"],
        )

    def test_mecab_tokenizer_lower(self):
        tokenizer = MecabTokenizer(do_lower_case=True, mecab_dic="ipadic")

        self.assertListEqual(
            tokenizer.tokenize(" \tアップルストアでiPhone8 が  \n 発売された 。  "),
            ["アップルストア", "で", "iphone", "8", "が", "発売", "さ", "れ", "た", "。"],
        )

    def test_mecab_tokenizer_with_option(self):
        try:
            tokenizer = MecabTokenizer(
                do_lower_case=True, normalize_text=False, mecab_option="-d /usr/local/lib/mecab/dic/jumandic"
            )
        except RuntimeError:
            # if the dictionary does not exist on the system, MecabTokenizer raises this error.
            return

        self.assertListEqual(
            tokenizer.tokenize(" \tアップルストアでiPhone8 が  \n 発売された 。  "),
            ["アップルストア", "で", "iPhone", "8", "が", "発売", "さ", "れた", "\u3000", "。"],
        )

    def test_mecab_tokenizer_no_normalize(self):
        tokenizer = MecabTokenizer(normalize_text=False, mecab_dic="ipadic")

        self.assertListEqual(
            tokenizer.tokenize(" \tアップルストアでiPhone8 が  \n 発売された 。  "),
            ["アップルストア", "で", "iPhone", "8", "が", "発売", "さ", "れ", "た", " ", "。"],
        )

    def test_wordpiece_tokenizer(self):
        vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "こんにちは", "こん", "にちは", "ばんは", "##こん", "##にちは", "##ばんは"]

        vocab = {}
        for (i, token) in enumerate(vocab_tokens):
            vocab[token] = i
        tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")

        self.assertListEqual(tokenizer.tokenize(""), [])

        self.assertListEqual(tokenizer.tokenize("こんにちは"), ["こんにちは"])

        self.assertListEqual(tokenizer.tokenize("こんばんは"), ["こん", "##ばんは"])

        self.assertListEqual(tokenizer.tokenize("こんばんは こんばんにちは こんにちは"), ["こん", "##ばんは", "[UNK]", "こんにちは"])

    def test_sequence_builders(self):
        tokenizer = self.tokenizer_class.from_pretrained("cl-tohoku/bert-base-japanese")

        text = tokenizer.encode("ありがとう。", add_special_tokens=False)
        text_2 = tokenizer.encode("どういたしまして。", add_special_tokens=False)

        encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
        encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)

        # 2 is for "[CLS]", 3 is for "[SEP]"
        assert encoded_sentence == [2] + text + [3]
        assert encoded_pair == [2] + text + [3] + text_2 + [3]


@custom_tokenizers
class BertJapaneseCharacterTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

    tokenizer_class = BertJapaneseTokenizer

    def setUp(self):
        super().setUp()

        vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "こ", "ん", "に", "ち", "は", "ば", "世", "界", "、", "。"]

        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

    def get_tokenizer(self, **kwargs):
        return BertJapaneseTokenizer.from_pretrained(self.tmpdirname, subword_tokenizer_type="character", **kwargs)

    def get_input_output_texts(self, tokenizer):
        input_text = "こんにちは、世界。 \nこんばんは、世界。"
        output_text = "こ ん に ち は 、 世 界 。 こ ん ば ん は 、 世 界 。"
        return input_text, output_text

    def test_pretokenized_inputs(self):
        pass  # TODO add if relevant

    def test_maximum_encoding_length_pair_input(self):
        pass  # TODO add if relevant

    def test_maximum_encoding_length_single_input(self):
        pass  # TODO add if relevant

    def test_full_tokenizer(self):
        tokenizer = self.tokenizer_class(self.vocab_file, subword_tokenizer_type="character")

        tokens = tokenizer.tokenize("こんにちは、世界。 \nこんばんは、世界。")
        self.assertListEqual(
            tokens, ["こ", "ん", "に", "ち", "は", "、", "世", "界", "。", "こ", "ん", "ば", "ん", "は", "、", "世", "界", "。"]
        )
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens), [3, 4, 5, 6, 7, 11, 9, 10, 12, 3, 4, 8, 4, 7, 11, 9, 10, 12]
        )

    def test_character_tokenizer(self):
        vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "こ", "ん", "に", "ち", "は", "ば", "世", "界", "、", "。"]

        vocab = {}
        for (i, token) in enumerate(vocab_tokens):
            vocab[token] = i
        tokenizer = CharacterTokenizer(vocab=vocab, unk_token="[UNK]")

        self.assertListEqual(tokenizer.tokenize(""), [])

        self.assertListEqual(tokenizer.tokenize("こんにちは"), ["こ", "ん", "に", "ち", "は"])

        self.assertListEqual(tokenizer.tokenize("こんにちほ"), ["こ", "ん", "に", "ち", "[UNK]"])

    def test_sequence_builders(self):
        tokenizer = self.tokenizer_class.from_pretrained("cl-tohoku/bert-base-japanese-char")

        text = tokenizer.encode("ありがとう。", add_special_tokens=False)
        text_2 = tokenizer.encode("どういたしまして。", add_special_tokens=False)

        encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
        encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)

        # 2 is for "[CLS]", 3 is for "[SEP]"
        assert encoded_sentence == [2] + text + [3]
        assert encoded_pair == [2] + text + [3] + text_2 + [3]


@custom_tokenizers
class AutoTokenizerCustomTest(unittest.TestCase):
    def test_tokenizer_bert_japanese(self):
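        # AutoTokenizer should resolve this checkpoint to BertJapaneseTokenizer.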
        EXAMPLE_BERT_JAPANESE_ID = "cl-tohoku/bert-base-japanese"
        tokenizer = AutoTokenizer.from_pretrained(EXAMPLE_BERT_JAPANESE_ID)
        self.assertIsInstance(tokenizer, BertJapaneseTokenizer)