# coding=utf-8
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
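"""Tests for BertJapaneseTokenizer and its MeCab, WordPiece and character-level sub-tokenizers."""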


import os
import pickle
import unittest

from transformers import AutoTokenizer
from transformers.models.bert_japanese.tokenization_bert_japanese import (
    VOCAB_FILES_NAMES,
    BertJapaneseTokenizer,
    CharacterTokenizer,
    MecabTokenizer,
    WordpieceTokenizer,
)
from transformers.testing_utils import custom_tokenizers

from .test_tokenization_common import TokenizerTesterMixin


@custom_tokenizers
class BertJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

    tokenizer_class = BertJapaneseTokenizer
    test_rust_tokenizer = False
    space_between_special_tokens = True

    def setUp(self):
        super().setUp()

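        # Minimal WordPiece-style vocab for the tests below: whole words plus their "##"-prefixed subword pieces.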
        vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "こんにちは",
            "こん",
            "にちは",
            "ばんは",
            "##こん",
            "##にちは",
            "##ばんは",
            "世界",
            "##世界",
            "、",
            "##、",
            "。",
            "##。",
        ]

        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

    def get_input_output_texts(self, tokenizer):
        input_text = "こんにちは、世界。 \nこんばんは、世界。"
        output_text = "こんにちは 、 世界 。 こんばんは 、 世界 。"
        return input_text, output_text

    def get_clean_sequence(self, tokenizer):
        input_text, output_text = self.get_input_output_texts(tokenizer)
        ids = tokenizer.encode(output_text, add_special_tokens=False)
        text = tokenizer.decode(ids, clean_up_tokenization_spaces=False)
        return text, ids

    def test_pretokenized_inputs(self):
        pass  # TODO add if relevant

    def test_maximum_encoding_length_pair_input(self):
        pass  # TODO add if relevant

    def test_maximum_encoding_length_single_input(self):
        pass  # TODO add if relevant

    def test_full_tokenizer(self):
        tokenizer = self.tokenizer_class(self.vocab_file)

        tokens = tokenizer.tokenize("こんにちは、世界。\nこんばんは、世界。")
        self.assertListEqual(tokens, ["こんにちは", "、", "世界", "。", "こん", "##ばんは", "、", "世界", "。"])
        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [3, 12, 10, 14, 4, 9, 12, 10, 14])

    def test_pickle_mecab_tokenizer(self):
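        # The MeCab-based word tokenizer holds a reference to an external tagger, so make sure
        # the tokenizer can be pickled and reloaded without changing its output.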
        tokenizer = self.tokenizer_class(self.vocab_file, word_tokenizer_type="mecab")
        self.assertIsNotNone(tokenizer)

        text = "こんにちは、世界。\nこんばんは、世界。"
        tokens = tokenizer.tokenize(text)
        self.assertListEqual(tokens, ["こんにちは", "、", "世界", "。", "こん", "##ばんは", "、", "世界", "。"])
        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [3, 12, 10, 14, 4, 9, 12, 10, 14])

        filename = os.path.join(self.tmpdirname, "tokenizer.bin")
        with open(filename, "wb") as handle:
            pickle.dump(tokenizer, handle)

        with open(filename, "rb") as handle:
            tokenizer_new = pickle.load(handle)

        tokens_loaded = tokenizer_new.tokenize(text)

        self.assertListEqual(tokens, tokens_loaded)

    def test_mecab_tokenizer_ipadic(self):
        tokenizer = MecabTokenizer(mecab_dic="ipadic")

        self.assertListEqual(
            tokenizer.tokenize(" \tアップルストアでiPhone8 が  \n 発売された 。  "),
            ["アップルストア", "で", "iPhone", "8", "が", "発売", "さ", "れ", "た", "。"],
        )

    def test_mecab_tokenizer_unidic_lite(self):
        try:
            tokenizer = MecabTokenizer(mecab_dic="unidic_lite")
        except ModuleNotFoundError:
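            # unidic_lite is an optional dependency; just skip the check when it is not installed.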
            return

        self.assertListEqual(
            tokenizer.tokenize(" \tアップルストアでiPhone8 が  \n 発売された 。  "),
            ["アップル", "ストア", "で", "iPhone", "8", "が", "発売", "さ", "れ", "た", "。"],
        )

    def test_mecab_tokenizer_unidic(self):
        try:
            tokenizer = MecabTokenizer(mecab_dic="unidic")
        except ModuleNotFoundError:
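            # Likewise, skip when the optional unidic dictionary package is not installed.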
            return

        self.assertListEqual(
            tokenizer.tokenize(" \tアップルストアでiPhone8 が  \n 発売された 。  "),
            ["アップル", "ストア", "で", "iPhone", "8", "が", "発売", "さ", "れ", "た", "。"],
        )

    def test_mecab_tokenizer_lower(self):
        tokenizer = MecabTokenizer(do_lower_case=True, mecab_dic="ipadic")

        self.assertListEqual(
            tokenizer.tokenize(" \tアップルストアでiPhone8 が  \n 発売された 。  "),
            ["アップルストア", "で", "iphone", "8", "が", "発売", "さ", "れ", "た", "。"],
        )

    def test_mecab_tokenizer_with_option(self):
        try:
            tokenizer = MecabTokenizer(
                do_lower_case=True, normalize_text=False, mecab_option="-d /usr/local/lib/mecab/dic/jumandic"
            )
        except RuntimeError:
            # If the dictionary is not installed on the system, constructing the MecabTokenizer raises this error.
            return

        self.assertListEqual(
            tokenizer.tokenize(" \tアップルストアでiPhone8 が  \n 発売された 。  "),
            ["アップルストア", "で", "iPhone", "8", "が", "発売", "さ", "れた", "\u3000", "。"],
        )

    def test_mecab_tokenizer_no_normalize(self):
        tokenizer = MecabTokenizer(normalize_text=False, mecab_dic="ipadic")
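        # With normalize_text=False, the space character inside the input is kept as its own token.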

        self.assertListEqual(
            tokenizer.tokenize(" \tアップルストアでiPhone8 が  \n 発売された 。  "),
            ["アップルストア", "で", "iPhone", "8", "が", "発売", "さ", "れ", "た", " ", "。"],
        )

    def test_wordpiece_tokenizer(self):
        vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "こんにちは", "こん", "にちは" "ばんは", "##こん", "##にちは", "##ばんは"]

        vocab = {token: i for i, token in enumerate(vocab_tokens)}
        tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")

        self.assertListEqual(tokenizer.tokenize(""), [])

        self.assertListEqual(tokenizer.tokenize("こんにちは"), ["こんにちは"])

        self.assertListEqual(tokenizer.tokenize("こんばんは"), ["こん", "##ばんは"])

        self.assertListEqual(tokenizer.tokenize("こんばんは こんばんにちは こんにちは"), ["こん", "##ばんは", "[UNK]", "こんにちは"])

    def test_sequence_builders(self):
        tokenizer = self.tokenizer_class.from_pretrained("cl-tohoku/bert-base-japanese")

        text = tokenizer.encode("ありがとう。", add_special_tokens=False)
        text_2 = tokenizer.encode("どういたしまして。", add_special_tokens=False)

        encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
        encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)

        # 2 is for "[CLS]", 3 is for "[SEP]"
        assert encoded_sentence == [2] + text + [3]
        assert encoded_pair == [2] + text + [3] + text_2 + [3]


@custom_tokenizers
class BertJapaneseCharacterTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

    tokenizer_class = BertJapaneseTokenizer
    test_rust_tokenizer = False

    def setUp(self):
        super().setUp()

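        # Minimal single-character vocab for the character-level subword tokenizer tests.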
        vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "こ", "ん", "に", "ち", "は", "ば", "世", "界", "、", "。"]

        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

    def get_tokenizer(self, **kwargs):
        return BertJapaneseTokenizer.from_pretrained(self.tmpdirname, subword_tokenizer_type="character", **kwargs)

    def get_input_output_texts(self, tokenizer):
        input_text = "こんにちは、世界。 \nこんばんは、世界。"
        output_text = "こ ん に ち は 、 世 界 。 こ ん ば ん は 、 世 界 。"
        return input_text, output_text

    def test_pretokenized_inputs(self):
        pass  # TODO add if relevant

    def test_maximum_encoding_length_pair_input(self):
        pass  # TODO add if relevant

    def test_maximum_encoding_length_single_input(self):
        pass  # TODO add if relevant

    def test_full_tokenizer(self):
        tokenizer = self.tokenizer_class(self.vocab_file, subword_tokenizer_type="character")

        tokens = tokenizer.tokenize("こんにちは、世界。 \nこんばんは、世界。")
        self.assertListEqual(
            tokens, ["こ", "ん", "に", "ち", "は", "、", "世", "界", "。", "こ", "ん", "ば", "ん", "は", "、", "世", "界", "。"]
        )
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens), [3, 4, 5, 6, 7, 11, 9, 10, 12, 3, 4, 8, 4, 7, 11, 9, 10, 12]
        )

    def test_character_tokenizer(self):
        vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "こ", "ん", "に", "ち", "は", "ば", "世", "界" "、", "。"]

        vocab = {token: i for i, token in enumerate(vocab_tokens)}
        tokenizer = CharacterTokenizer(vocab=vocab, unk_token="[UNK]")

        self.assertListEqual(tokenizer.tokenize(""), [])

        self.assertListEqual(tokenizer.tokenize("こんにちは"), ["こ", "ん", "に", "ち", "は"])

        self.assertListEqual(tokenizer.tokenize("こんにちほ"), ["こ", "ん", "に", "ち", "[UNK]"])

    def test_sequence_builders(self):
        tokenizer = self.tokenizer_class.from_pretrained("cl-tohoku/bert-base-japanese-char")

        text = tokenizer.encode("ありがとう。", add_special_tokens=False)
        text_2 = tokenizer.encode("どういたしまして。", add_special_tokens=False)

        encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
        encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)

        # 2 is for "[CLS]", 3 is for "[SEP]"
        assert encoded_sentence == [2] + text + [3]
        assert encoded_pair == [2] + text + [3] + text_2 + [3]


@custom_tokenizers
class AutoTokenizerCustomTest(unittest.TestCase):
    def test_tokenizer_bert_japanese(self):
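        # AutoTokenizer should resolve this checkpoint to the Japanese-specific tokenizer class.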
        EXAMPLE_BERT_JAPANESE_ID = "cl-tohoku/bert-base-japanese"
        tokenizer = AutoTokenizer.from_pretrained(EXAMPLE_BERT_JAPANESE_ID)
        self.assertIsInstance(tokenizer, BertJapaneseTokenizer)