# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for RoBERTa."""
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import json
import logging
import re
from io import open
import six
import os

from .tokenization_utils import PreTrainedTokenizer
from .tokenization_gpt2 import GPT2Tokenizer

logger = logging.getLogger(__name__)

DICT_FILES_NAMES = {
    'dict_file': 'dict.txt',
}

PRETRAINED_DICT_FILES_MAP = {
    'dict_file':
        {
            'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
            'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
            'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
        },
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    'roberta-base': 512,
    'roberta-large': 512,
    'roberta-large-mnli': 512,
}

SPACE_NORMALIZER = re.compile(r"\s+")

def tokenize_line(line):
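    """Split a line into whitespace-separated tokens after collapsing runs of whitespace."""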
    line = SPACE_NORMALIZER.sub(" ", line)
    line = line.strip()
    return line.split()


class Dictionary(object):
    """
    A mapping from symbols to consecutive integers

    From Facebook's fairseq.
    """

    def __init__(
        self,
        pad='<pad>',
        eos='</s>',
        unk='<unk>',
        bos='<s>',
        extra_special_symbols=None,
    ):
        self.unk_word, self.pad_word, self.eos_word = unk, pad, eos
        self.symbols = []
        self.count = []
        self.indices = {}
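        # Special symbols are registered first: <s>=0, <pad>=1, </s>=2, <unk>=3 (fairseq convention).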
        self.bos_index = self.add_symbol(bos)
        self.pad_index = self.add_symbol(pad)
        self.eos_index = self.add_symbol(eos)
        self.unk_index = self.add_symbol(unk)
        if extra_special_symbols:
            for s in extra_special_symbols:
                self.add_symbol(s)
        self.nspecial = len(self.symbols)

    def __getitem__(self, idx):
        if idx < len(self.symbols):
            return self.symbols[idx]
        return self.unk_word

    def index(self, sym):
        """Returns the index of the specified symbol"""
        assert isinstance(sym, str)
        if sym in self.indices:
            return self.indices[sym]
        return self.unk_index

    def add_symbol(self, word, n=1):
        """Adds a word to the dictionary"""
        if word in self.indices:
            idx = self.indices[word]
            self.count[idx] = self.count[idx] + n
            return idx
        else:
            idx = len(self.symbols)
            self.indices[word] = idx
            self.symbols.append(word)
            self.count.append(n)
            return idx

    @classmethod
    def load(cls, f, ignore_utf_errors=False):
        """Loads the dictionary from a text file with the format:

        ```
        <symbol0> <count0>
        <symbol1> <count1>
        ...
        ```
        """
        d = cls()
        d.add_from_file(f, ignore_utf_errors)
        return d

    def add_from_file(self, f, ignore_utf_errors=False):
        """
        Loads a pre-existing dictionary from a text file and adds its symbols
        to this instance.
        """
        if isinstance(f, six.string_types):
            try:
                if not ignore_utf_errors:
                    with open(f, 'r', encoding='utf-8') as fd:
                        self.add_from_file(fd)
                else:
                    with open(f, 'r', encoding='utf-8', errors='ignore') as fd:
                        self.add_from_file(fd)
            except FileNotFoundError as fnfe:
                raise fnfe
            except UnicodeError:
                raise Exception("Incorrect encoding detected in {}, please "
                                "rebuild the dataset".format(f))
            return

        lines = f.read().splitlines()
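        # Each line is "<symbol> <count>"; split on the last space so symbols may themselves contain spaces.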
        for line in lines:
            idx = line.rfind(' ')
            if idx == -1:
                raise ValueError("Incorrect dictionary format, expected '<token> <cnt>'")
            word = line[:idx]
            count = int(line[idx + 1:])
            self.indices[word] = len(self.symbols)
            self.symbols.append(word)
            self.count.append(count)

    def encode_line(self, line, line_tokenizer=tokenize_line, add_if_not_exist=True,
                    consumer=None, append_eos=True, reverse_order=False):
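        """Convert a line of text into a list of dictionary indices, optionally appending the </s> index."""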
        words = line_tokenizer(line)
        if reverse_order:
            words = list(reversed(words))
        nwords = len(words)
        ids = [0] * (nwords + 1 if append_eos else nwords)

        for i, word in enumerate(words):
            if add_if_not_exist:
                idx = self.add_symbol(word)
            else:
                idx = self.index(word)
            if consumer is not None:
                consumer(word, idx)
            ids[i] = idx
        if append_eos:
            ids[nwords] = self.eos_index
        return ids


class RobertaTokenizer(PreTrainedTokenizer):
    """
    RoBERTa tokenizer. Peculiarities:
        - Runs the GPT-2 byte-level BPE tokenizer, then remaps the resulting BPE ids to
          RoBERTa's own vocabulary indices through a fairseq-style dictionary (dict.txt).
    """
    vocab_files_names = DICT_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_DICT_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(self, dict_file, bpe_tokenizer=None, bos_token="<s>", eos_token="</s>", sep_token="</s>",
                 cls_token="<s>", unk_token="<unk>", **kwargs):
        super(RobertaTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, sep_token=sep_token,
                                               cls_token=cls_token, unk_token=unk_token, **kwargs)

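        # Underlying GPT-2 byte-level BPE (downloaded if not supplied) and the fairseq-style
        # dictionary that maps stringified BPE ids to RoBERTa vocabulary indices.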
        self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2") if bpe_tokenizer is None else bpe_tokenizer
        self.dictionary = Dictionary.load(dict_file)

    @property
    def vocab_size(self):
        return len(self.dictionary.indices)

    def _tokenize(self, text):
        """ Use GPT-2 Tokenizer """
        return self.gpt2_tokenizer._tokenize(text)

    def _convert_token_to_id(self, token):
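        # Index 3 is <unk> in the fairseq dictionary (bos=0, pad=1, eos=2, unk=3). Tokens that are
        # not literal dict.txt symbols are looked up via their GPT-2 BPE id, stored as a string.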
        if self.dictionary.index(token) != 3:
            return self.dictionary.index(token)
        return self.dictionary.index(str(self.gpt2_tokenizer.convert_tokens_to_ids(token)))

    def _convert_id_to_token(self, index):
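        # Dictionary symbols are mostly GPT-2 BPE ids stored as strings; map those back through
        # the GPT-2 tokenizer, and return special symbols (<s>, <pad>, ...) unchanged.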
        symbol = self.dictionary[index]
        try:
            idx = int(symbol)
            return self.gpt2_tokenizer._convert_id_to_token(idx)
        except ValueError:
            return symbol

    def convert_tokens_to_string(self, tokens):
        return self.gpt2_tokenizer.convert_tokens_to_string(tokens)

    def convert_tokens_to_ids(self, tokens, no_sep_cls_tokens=False):
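        # Wrap the converted sequence as <s> ... </s> (RoBERTa's single-sequence format), unless a
        # single token was passed or no_sep_cls_tokens is set.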
        cls = [self._convert_token_to_id(self.cls_token)]
        tokens = super(RobertaTokenizer, self).convert_tokens_to_ids(tokens)
        sep = [self._convert_token_to_id(self.sep_token)]
        return (cls + tokens + sep) if (isinstance(tokens, list) and not no_sep_cls_tokens) else tokens

    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
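        # Strip the leading <s> and trailing </s> added by convert_tokens_to_ids above.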
        return super(RobertaTokenizer, self).convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)[1:-1]

    def save_vocabulary(self, save_directory):
        """Save the tokenizer vocabulary and merge files to a directory."""
        if not os.path.isdir(save_directory):
            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
            return
        dict_file = os.path.join(save_directory, DICT_FILES_NAMES['dict_file'])

        with open(dict_file, 'w', encoding='utf-8') as f:
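            # Write non-special symbols in fairseq "<symbol> <count>" format, one per line.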
            for i in range(self.dictionary.nspecial, len(self.dictionary.count)):
                f.write(f"{list(self.dictionary.indices.keys())[i]} {self.dictionary.count[i]}\n")

        vocab_files = self.gpt2_tokenizer.save_pretrained(save_directory)

        return vocab_files + (dict_file,)