# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import unittest
from io import open

from transformers.tokenization_bert import (
    VOCAB_FILES_NAMES,
    BasicTokenizer,
    BertTokenizer,
    WordpieceTokenizer,
    _is_control,
    _is_punctuation,
    _is_whitespace,
)

from .test_tokenization_common import TokenizerTesterMixin
from .utils import slow


class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

    tokenizer_class = BertTokenizer

    def setUp(self):
        super(BertTokenizationTest, self).setUp()

        vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "want",
            "##want",
            "##ed",
            "wa",
            "un",
            "runn",
            "##ing",
            ",",
            "low",
            "lowest",
        ]
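        # Write the toy vocabulary one token per line so BertTokenizer.from_pretrained can reload it from tmpdirname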
        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

    def get_tokenizer(self, **kwargs):
        return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)

    def get_input_output_texts(self):
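        # Raw text and the string expected back after an encode/decode round trip in the shared mixin tests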
        input_text = "UNwant\u00E9d,running"
        output_text = "unwanted, running"
        return input_text, output_text

    def test_full_tokenizer(self):
        tokenizer = self.tokenizer_class(self.vocab_file)

        tokens = tokenizer.tokenize("UNwant\u00E9d,running")
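        # Lower-casing and accent stripping turn "UNwant\u00e9d" into "unwanted" before WordPiece splitting;
        # the expected ids are the positions of each piece in the toy vocabulary from setUp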
        self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])

    def test_chinese(self):
        tokenizer = BasicTokenizer()
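        # BasicTokenizer surrounds CJK characters with whitespace, so each one becomes its own token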

        self.assertListEqual(tokenizer.tokenize("ah\u535A\u63A8zz"), ["ah", "\u535A", "\u63A8", "zz"])

    def test_basic_tokenizer_lower(self):
        tokenizer = BasicTokenizer(do_lower_case=True)
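        # With do_lower_case=True the text is lower-cased and accents are stripped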

        self.assertListEqual(
            tokenizer.tokenize(" \tHeLLo!how  \n Are yoU?  "), ["hello", "!", "how", "are", "you", "?"]
        )
        self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])

    def test_basic_tokenizer_no_lower(self):
        tokenizer = BasicTokenizer(do_lower_case=False)
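        # With do_lower_case=False the original casing is preserved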

        self.assertListEqual(
            tokenizer.tokenize(" \tHeLLo!how  \n Are yoU?  "), ["HeLLo", "!", "how", "Are", "yoU", "?"]
        )

    def test_wordpiece_tokenizer(self):
        vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"]

        vocab = {}
        for i, token in enumerate(vocab_tokens):
            vocab[token] = i
        tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")
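        # WordPiece uses greedy longest-match-first splitting; a word with no valid segmentation maps to unk_token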

        self.assertListEqual(tokenizer.tokenize(""), [])

        self.assertListEqual(tokenizer.tokenize("unwanted running"), ["un", "##want", "##ed", "runn", "##ing"])

        self.assertListEqual(tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])

    def test_is_whitespace(self):
        self.assertTrue(_is_whitespace(" "))
        self.assertTrue(_is_whitespace("\t"))
        self.assertTrue(_is_whitespace("\r"))
        self.assertTrue(_is_whitespace("\n"))
        self.assertTrue(_is_whitespace("\u00A0"))
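        # \u00A0 is a non-breaking space (Unicode category Zs) and therefore counts as whitespace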

        self.assertFalse(_is_whitespace("A"))
        self.assertFalse(_is_whitespace("-"))

    def test_is_control(self):
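        # \u0005 (ENQ) is a Cc control character; tab, carriage return and spaces are whitespace, not control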
        self.assertTrue(_is_control("\u0005"))

        self.assertFalse(_is_control("A"))
        self.assertFalse(_is_control(" "))
        self.assertFalse(_is_control("\t"))
        self.assertFalse(_is_control("\r"))

    def test_is_punctuation(self):
        self.assertTrue(_is_punctuation("-"))
        self.assertTrue(_is_punctuation("$"))
        self.assertTrue(_is_punctuation("`"))
        self.assertTrue(_is_punctuation("."))
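        # Non-alphanumeric ASCII symbols such as "$" and "`" are treated as punctuation as well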

        self.assertFalse(_is_punctuation("A"))
        self.assertFalse(_is_punctuation(" "))

    @slow
    def test_sequence_builders(self):
        tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased")

        text = tokenizer.encode("sequence builders", add_special_tokens=False)
        text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)

        encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
        encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
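        # [CLS] has id 101 and [SEP] has id 102 in the bert-base-uncased vocabulary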

        assert encoded_sentence == [101] + text + [102]
        assert encoded_pair == [101] + text + [102] + text_2 + [102]