# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import shutil
import unittest
from io import open

import pytest

from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
                                                  BertTokenizer,
                                                  WordpieceTokenizer,
                                                  _is_control, _is_punctuation,
                                                  _is_whitespace, PRETRAINED_VOCAB_ARCHIVE_MAP)


class TokenizationTest(unittest.TestCase):
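
    # End-to-end BertTokenizer check: build a tiny vocabulary file, tokenize
    # with it, then save and reload the vocabulary and expect identical output.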
    def test_full_tokenizer(self):
        vocab_tokens = [
            "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
            "##ing", ","
        ]
        with open("/tmp/bert_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
            vocab_file = vocab_writer.name

        tokenizer = BertTokenizer(vocab_file)
        os.remove(vocab_file)

        tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
        self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])

        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])

        # Save the vocabulary, rebuild the tokenizer from the saved file, and
        # check that the reloaded tokenizer behaves identically.
        vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
        tokenizer = tokenizer.from_pretrained(vocab_file)
        os.remove(vocab_file)

        tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
        self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])

        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
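
    # Instantiating from a pretrained shortcut name downloads the vocabulary,
    # so the slow marker lets runs deselect this test (e.g. -m "not slow");
    # only the first entry of the archive map is exercised.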
    @pytest.mark.slow
    def test_tokenizer_from_pretrained(self):
        cache_dir = "/tmp/pytorch_pretrained_bert_test/"
        for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
            tokenizer = BertTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
            shutil.rmtree(cache_dir)
            self.assertIsNotNone(tokenizer)
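
    # BasicTokenizer splits each CJK character into its own token while
    # leaving contiguous Latin text together.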
    def test_chinese(self):
        tokenizer = BasicTokenizer()

        self.assertListEqual(
            tokenizer.tokenize(u"ah\u535A\u63A8zz"),
            [u"ah", u"\u535A", u"\u63A8", u"zz"])
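
    # With do_lower_case=True the tokenizer lowercases text and strips accents.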
    def test_basic_tokenizer_lower(self):
        tokenizer = BasicTokenizer(do_lower_case=True)

        self.assertListEqual(
            tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
            ["hello", "!", "how", "are", "you", "?"])
        self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
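
    # With do_lower_case=False the original casing is preserved.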
    def test_basic_tokenizer_no_lower(self):
        tokenizer = BasicTokenizer(do_lower_case=False)

        self.assertListEqual(
            tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
            ["HeLLo", "!", "how", "Are", "yoU", "?"])
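
    # WordpieceTokenizer greedily matches the longest subword piece in the
    # vocabulary; non-initial pieces carry the "##" prefix.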
    def test_wordpiece_tokenizer(self):
        vocab_tokens = [
            "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
            "##ing"
        ]

        vocab = {}
        for (i, token) in enumerate(vocab_tokens):
            vocab[token] = i
        tokenizer = WordpieceTokenizer(vocab=vocab)

        self.assertListEqual(tokenizer.tokenize(""), [])

        self.assertListEqual(
            tokenizer.tokenize("unwanted running"),
            ["un", "##want", "##ed", "runn", "##ing"])

        # A word with any piece missing from the vocabulary maps whole to [UNK].
        self.assertListEqual(
            tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
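
    # _is_whitespace accepts ASCII whitespace and Unicode category Zs
    # characters such as the no-break space (U+00A0).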
    def test_is_whitespace(self):
        self.assertTrue(_is_whitespace(u" "))
        self.assertTrue(_is_whitespace(u"\t"))
        self.assertTrue(_is_whitespace(u"\r"))
        self.assertTrue(_is_whitespace(u"\n"))
        self.assertTrue(_is_whitespace(u"\u00A0"))

        self.assertFalse(_is_whitespace(u"A"))
        self.assertFalse(_is_whitespace(u"-"))
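
    # _is_control flags control characters such as U+0005, but deliberately
    # treats \t, \r and \n as whitespace rather than control.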
    def test_is_control(self):
        self.assertTrue(_is_control(u"\u0005"))

        self.assertFalse(_is_control(u"A"))
        self.assertFalse(_is_control(u" "))
        self.assertFalse(_is_control(u"\t"))
        self.assertFalse(_is_control(u"\r"))
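
    # _is_punctuation covers Unicode P* categories plus the non-alphanumeric
    # ASCII range, so "$" and "`" count as punctuation here.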
    def test_is_punctuation(self):
        self.assertTrue(_is_punctuation(u"-"))
        self.assertTrue(_is_punctuation(u"$"))
        self.assertTrue(_is_punctuation(u"`"))
        self.assertTrue(_is_punctuation(u"."))

        self.assertFalse(_is_punctuation(u"A"))
        self.assertFalse(_is_punctuation(u" "))


if __name__ == '__main__':
    unittest.main()
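
# Typical invocation (the file path is an assumption; adjust to your checkout):
#   python -m pytest tokenization_test.py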