# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import pickle
import unittest

from transformers.models.bert_japanese.tokenization_bert_japanese import (
    VOCAB_FILES_NAMES,
    BertJapaneseTokenizer,
    CharacterTokenizer,
    MecabTokenizer,
    WordpieceTokenizer,
)
from transformers.testing_utils import custom_tokenizers

from .test_tokenization_common import TokenizerTesterMixin


@custom_tokenizers
class BertJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

    tokenizer_class = BertJapaneseTokenizer
    space_between_special_tokens = True

    def setUp(self):
        super().setUp()

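        # Minimal vocabulary covering the whole words and "##" continuation pieces used by the tests below.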
        vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "こんにちは",
            "こん",
            "にちは",
            "ばんは",
            "##こん",
            "##にちは",
            "##ばんは",
            "世界",
            "##世界",
            "、",
            "##、",
            "。",
            "##。",
        ]

        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

    def get_input_output_texts(self, tokenizer):
        input_text = "こんにちは、世界。 \nこんばんは、世界。"
        output_text = "こんにちは 、 世界 。 こんばんは 、 世界 。"
        return input_text, output_text

    def get_clean_sequence(self, tokenizer):
        input_text, output_text = self.get_input_output_texts(tokenizer)
        ids = tokenizer.encode(output_text, add_special_tokens=False)
        text = tokenizer.decode(ids, clean_up_tokenization_spaces=False)
        return text, ids

    def test_pretokenized_inputs(self):
        pass  # TODO add if relevant

    def test_maximum_encoding_length_pair_input(self):
        pass  # TODO add if relevant

    def test_maximum_encoding_length_single_input(self):
        pass  # TODO add if relevant

    def test_full_tokenizer(self):
        tokenizer = self.tokenizer_class(self.vocab_file)

        tokens = tokenizer.tokenize("こんにちは、世界。\nこんばんは、世界。")
        self.assertListEqual(tokens, ["こんにちは", "、", "世界", "。", "こん", "##ばんは", "、", "世界", "。"])
        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [3, 12, 10, 14, 4, 9, 12, 10, 14])

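    # Make sure a MeCab-based tokenizer can be pickled and reloaded without changing its output.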
    def test_pickle_mecab_tokenizer(self):
        tokenizer = self.tokenizer_class(self.vocab_file, word_tokenizer_type="mecab")
        self.assertIsNotNone(tokenizer)

        text = "こんにちは、世界。\nこんばんは、世界。"
        tokens = tokenizer.tokenize(text)
        self.assertListEqual(tokens, ["こんにちは", "、", "世界", "。", "こん", "##ばんは", "、", "世界", "。"])
        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [3, 12, 10, 14, 4, 9, 12, 10, 14])

        filename = os.path.join(self.tmpdirname, "tokenizer.bin")
        with open(filename, "wb") as handle:
            pickle.dump(tokenizer, handle)

        with open(filename, "rb") as handle:
            tokenizer_new = pickle.load(handle)

        tokens_loaded = tokenizer_new.tokenize(text)

        self.assertListEqual(tokens, tokens_loaded)

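    # The same raw sentence (with stray tabs, newlines and extra spaces) is tokenized with several MeCab dictionaries below.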
    def test_mecab_tokenizer_ipadic(self):
        tokenizer = MecabTokenizer(mecab_dic="ipadic")

        self.assertListEqual(
            tokenizer.tokenize(" \tアップルストアでiPhone8 が  \n 発売された 。  "),
            ["アップルストア", "で", "iPhone", "8", "が", "発売", "さ", "れ", "た", "。"],
        )

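    # unidic-lite and unidic are optional dependencies; these tests return early when the packages are not installed.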
    def test_mecab_tokenizer_unidic_lite(self):
        try:
            tokenizer = MecabTokenizer(mecab_dic="unidic_lite")
        except ModuleNotFoundError:
            return

        self.assertListEqual(
            tokenizer.tokenize(" \tアップルストアでiPhone8 が  \n 発売された 。  "),
            ["アップル", "ストア", "で", "iPhone", "8", "が", "発売", "さ", "れ", "た", "。"],
        )

    def test_mecab_tokenizer_unidic(self):
        try:
            tokenizer = MecabTokenizer(mecab_dic="unidic")
        except ModuleNotFoundError:
            return

        self.assertListEqual(
            tokenizer.tokenize(" \tアップルストアでiPhone8 が  \n 発売された 。  "),
            ["アップル", "ストア", "で", "iPhone", "8", "が", "発売", "さ", "れ", "た", "。"],
        )

    def test_mecab_tokenizer_lower(self):
        tokenizer = MecabTokenizer(do_lower_case=True, mecab_dic="ipadic")

        self.assertListEqual(
            tokenizer.tokenize(" \tアップルストアでiPhone8 が  \n 発売された 。  "),
            ["アップルストア", "で", "iphone", "8", "が", "発売", "さ", "れ", "た", "。"],
        )

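    # "-d" is MeCab's dictionary-directory option; this test needs a locally installed JUMAN dictionary.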
    def test_mecab_tokenizer_with_option(self):
        try:
            tokenizer = MecabTokenizer(
                do_lower_case=True, normalize_text=False, mecab_option="-d /usr/local/lib/mecab/dic/jumandic"
            )
        except RuntimeError:
            # if the JUMAN dictionary is not installed on the system, building the tokenizer raises RuntimeError
            return

        self.assertListEqual(
            tokenizer.tokenize(" \tアップルストアでiPhone8 が  \n 発売された 。  "),
            ["アップルストア", "で", "iPhone", "8", "が", "発売", "さ", "れた", "\u3000", "。"],
        )

    def test_mecab_tokenizer_no_normalize(self):
        tokenizer = MecabTokenizer(normalize_text=False, mecab_dic="ipadic")

        self.assertListEqual(
            tokenizer.tokenize(" \tアップルストアでiPhone8 が  \n 発売された 。  "),
            ["アップルストア", "で", "iPhone", "8", "が", "発売", "さ", "れ", "た", " ", "。"],
        )

    def test_wordpiece_tokenizer(self):
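        # Exercise the WordpieceTokenizer subword component directly, outside of BertJapaneseTokenizer.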
        vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "こんにちは", "こん", "にちは", "ばんは", "##こん", "##にちは", "##ばんは"]

        vocab = {token: i for i, token in enumerate(vocab_tokens)}
        tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")

        self.assertListEqual(tokenizer.tokenize(""), [])

        self.assertListEqual(tokenizer.tokenize("こんにちは"), ["こんにちは"])

        self.assertListEqual(tokenizer.tokenize("こんばんは"), ["こん", "##ばんは"])

        self.assertListEqual(tokenizer.tokenize("こんばんは こんばんにちは こんにちは"), ["こん", "##ばんは", "[UNK]", "こんにちは"])

    def test_sequence_builders(self):
        tokenizer = self.tokenizer_class.from_pretrained("cl-tohoku/bert-base-japanese")

        text = tokenizer.encode("ありがとう。", add_special_tokens=False)
        text_2 = tokenizer.encode("どういたしまして。", add_special_tokens=False)

        encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
        encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)

        # 2 is for "[CLS]", 3 is for "[SEP]"
        assert encoded_sentence == [2] + text + [3]
        assert encoded_pair == [2] + text + [3] + text_2 + [3]


@custom_tokenizers
class BertJapaneseCharacterTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

    tokenizer_class = BertJapaneseTokenizer

    def setUp(self):
        super().setUp()

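        # Character-level vocabulary: apart from the special tokens, every entry is a single character.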
        vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "こ", "ん", "に", "ち", "は", "ば", "世", "界", "、", "。"]

        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

    def get_tokenizer(self, **kwargs):
        return BertJapaneseTokenizer.from_pretrained(self.tmpdirname, subword_tokenizer_type="character", **kwargs)

    def get_input_output_texts(self, tokenizer):
        input_text = "こんにちは、世界。 \nこんばんは、世界。"
        output_text = "こ ん に ち は 、 世 界 。 こ ん ば ん は 、 世 界 。"
        return input_text, output_text

    def test_pretokenized_inputs(self):
        pass  # TODO add if relevant

    def test_maximum_encoding_length_pair_input(self):
        pass  # TODO add if relevant

    def test_maximum_encoding_length_single_input(self):
        pass  # TODO add if relevant

    def test_full_tokenizer(self):
        tokenizer = self.tokenizer_class(self.vocab_file, subword_tokenizer_type="character")

        tokens = tokenizer.tokenize("こんにちは、世界。 \nこんばんは、世界。")
        self.assertListEqual(
            tokens, ["こ", "ん", "に", "ち", "は", "、", "世", "界", "。", "こ", "ん", "ば", "ん", "は", "、", "世", "界", "。"]
        )
        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens), [3, 4, 5, 6, 7, 11, 9, 10, 12, 3, 4, 8, 4, 7, 11, 9, 10, 12]
        )

    def test_character_tokenizer(self):
        vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "こ", "ん", "に", "ち", "は", "ば", "世", "界", "、", "。"]

        vocab = {token: i for i, token in enumerate(vocab_tokens)}
        tokenizer = CharacterTokenizer(vocab=vocab, unk_token="[UNK]")

        self.assertListEqual(tokenizer.tokenize(""), [])

        self.assertListEqual(tokenizer.tokenize("こんにちは"), ["こ", "ん", "に", "ち", "は"])

        self.assertListEqual(tokenizer.tokenize("こんにちほ"), ["こ", "ん", "に", "ち", "[UNK]"])

    def test_sequence_builders(self):
        tokenizer = self.tokenizer_class.from_pretrained("cl-tohoku/bert-base-japanese-char")

        text = tokenizer.encode("ありがとう。", add_special_tokens=False)
        text_2 = tokenizer.encode("どういたしまして。", add_special_tokens=False)

        encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
        encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)

        # 2 is for "[CLS]", 3 is for "[SEP]"
        assert encoded_sentence == [2] + text + [3]
        assert encoded_pair == [2] + text + [3] + text_2 + [3]