# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import unittest
from io import open

from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
                                                  BertTokenizer,
                                                  WordpieceTokenizer,
                                                  _is_control, _is_punctuation,
                                                  _is_whitespace)


class TokenizationTest(unittest.TestCase):

    def test_full_tokenizer(self):
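        # Full pipeline check: BasicTokenizer cleanup followed by WordPiece
        # splitting against a toy vocabulary.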
        vocab_tokens = [
            "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
            "##ing", ","
        ]
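        # BertTokenizer loads vocabularies from a file path, so write the toy
        # vocab to disk first (assumes a writable /tmp).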
        with open("/tmp/bert_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

            vocab_file = vocab_writer.name

        tokenizer = BertTokenizer(vocab_file)
        os.remove(vocab_file)

        tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
        self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])

        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])

    def test_chinese(self):
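        # Basic tokenization splits each CJK character into its own token.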
        tokenizer = BasicTokenizer()

        self.assertListEqual(
            tokenizer.tokenize(u"ah\u535A\u63A8zz"),
            [u"ah", u"\u535A", u"\u63A8", u"zz"])

    def test_basic_tokenizer_lower(self):
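        # With do_lower_case=True the basic tokenizer lower-cases and strips
        # accents (e.g. \u00E9 -> e) in addition to splitting on whitespace
        # and punctuation.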
        tokenizer = BasicTokenizer(do_lower_case=True)

        self.assertListEqual(
            tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
            ["hello", "!", "how", "are", "you", "?"])
        self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])

    def test_basic_tokenizer_no_lower(self):
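        # With do_lower_case=False, case and accents are preserved; only
        # whitespace and punctuation splitting apply.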
        tokenizer = BasicTokenizer(do_lower_case=False)

        self.assertListEqual(
            tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
            ["HeLLo", "!", "how", "Are", "yoU", "?"])

    def test_wordpiece_tokenizer(self):
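        # WordPiece is greedy longest-match-first; continuation pieces carry
        # the "##" prefix, and a word with any out-of-vocab piece maps to
        # [UNK] as a whole.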
        vocab_tokens = [
            "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
            "##ing"
        ]

        vocab = {}
        for (i, token) in enumerate(vocab_tokens):
            vocab[token] = i
        tokenizer = WordpieceTokenizer(vocab=vocab)

        self.assertListEqual(tokenizer.tokenize(""), [])

        self.assertListEqual(
            tokenizer.tokenize("unwanted running"),
            ["un", "##want", "##ed", "runn", "##ing"])

        self.assertListEqual(
            tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])

    def test_is_whitespace(self):
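        # Tab, CR, LF, and the non-breaking space (\u00A0) all count as
        # whitespace here.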
        self.assertTrue(_is_whitespace(u" "))
        self.assertTrue(_is_whitespace(u"\t"))
        self.assertTrue(_is_whitespace(u"\r"))
        self.assertTrue(_is_whitespace(u"\n"))
        self.assertTrue(_is_whitespace(u"\u00A0"))

        self.assertFalse(_is_whitespace(u"A"))
        self.assertFalse(_is_whitespace(u"-"))

    def test_is_control(self):
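        # \u0005 (ENQUIRY) is a control character; tab and CR are treated as
        # whitespace rather than control characters.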
        self.assertTrue(_is_control(u"\u0005"))

        self.assertFalse(_is_control(u"A"))
        self.assertFalse(_is_control(u" "))
        self.assertFalse(_is_control(u"\t"))
        self.assertFalse(_is_control(u"\r"))

    def test_is_punctuation(self):
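        # All non-letter/number ASCII (including $ and the backtick) is
        # treated as punctuation.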
        self.assertTrue(_is_punctuation(u"-"))
        self.assertTrue(_is_punctuation(u"$"))
        self.assertTrue(_is_punctuation(u"`"))
        self.assertTrue(_is_punctuation(u"."))

        self.assertFalse(_is_punctuation(u"A"))
        self.assertFalse(_is_punctuation(u" "))


if __name__ == '__main__':
    unittest.main()