# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization classes for XLM-RoBERTa model."""
Aymeric Augustin's avatar
Aymeric Augustin committed
16

17
18
19
20
21
22

import logging
import os
from shutil import copyfile

from transformers.tokenization_utils import PreTrainedTokenizer
Aymeric Augustin's avatar
Aymeric Augustin committed
23

thomwolf's avatar
thomwolf committed
24
from .tokenization_xlnet import SPIECE_UNDERLINE
25

Aymeric Augustin's avatar
Aymeric Augustin committed
26

27
28
logger = logging.getLogger(__name__)

# File name under which ``XLMRobertaTokenizer.save_vocabulary`` stores the
# SentencePiece model inside the target directory.
VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}

# SentencePiece model URL for every pretrained checkpoint name.
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "xlm-roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-sentencepiece.bpe.model",
        "xlm-roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-sentencepiece.bpe.model",
        "xlm-roberta-large-finetuned-conll02-dutch": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-sentencepiece.bpe.model",
        "xlm-roberta-large-finetuned-conll02-spanish": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-sentencepiece.bpe.model",
        "xlm-roberta-large-finetuned-conll03-english": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-sentencepiece.bpe.model",
        "xlm-roberta-large-finetuned-conll03-german": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-sentencepiece.bpe.model",
    }
}

# Per-checkpoint size exposed by the tokenizer class as ``max_model_input_sizes``.
# Keys must stay in sync with PRETRAINED_VOCAB_FILES_MAP["vocab_file"].
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "xlm-roberta-base": 512,
    "xlm-roberta-large": 512,
    "xlm-roberta-large-finetuned-conll02-dutch": 512,
    "xlm-roberta-large-finetuned-conll02-spanish": 512,
    "xlm-roberta-large-finetuned-conll03-english": 512,
    "xlm-roberta-large-finetuned-conll03-german": 512,
}


class XLMRobertaTokenizer(PreTrainedTokenizer):
    """
    SentencePiece based tokenizer for XLM-RoBERTa.

    Adapted from RobertaTokenizer and XLNetTokenizer. Peculiarities:

        - requires `SentencePiece <https://github.com/google/sentencepiece>`_
        - token ids follow the original fairseq vocabulary layout, not the raw
          SentencePiece one (see the alignment table in ``__init__``)

    Args:
        vocab_file: path of the SentencePiece model file to load.
        bos_token, eos_token, sep_token, cls_token, unk_token, pad_token, mask_token:
            special tokens, forwarded to ``PreTrainedTokenizer.__init__``.
        **kwargs: forwarded unchanged to ``PreTrainedTokenizer.__init__``.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(
        self,
        vocab_file,
        bos_token="<s>",
        eos_token="</s>",
        sep_token="</s>",
        cls_token="<s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        **kwargs
    ):
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            **kwargs,
        )
        # Room left for real tokens once the special tokens are added:
        # "<s> X </s>" uses 2 of them, "<s> A </s></s> B </s>" uses 4.
        self.max_len_single_sentence = self.max_len - 2
        self.max_len_sentences_pair = self.max_len - 4

        self.sp_model = self._load_sp_model(vocab_file)
        self.vocab_file = vocab_file

        # Original fairseq vocab and spm vocab must be "aligned":
        # Vocab    |    0    |    1    |   2    |    3    |  4  |  5  |  6  |   7   |   8   |  9
        # -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
        # fairseq  | '<s>'   | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's'   | '▁de' | '-'
        # spm      | '<unk>' | '<s>'   | '</s>' | ','     | '.' | '▁' | 's' | '▁de' | '-'   | '▁a'

        # Mimic fairseq token-to-id alignment for the first 4 tokens
        self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}

        # The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
        self.fairseq_offset = 1

        # The last spm piece (spm id len - 1) lands on id ``len(sp_model)`` after
        # the shift, so "<mask>" takes the very next id. Using
        # ``len(self.fairseq_tokens_to_ids)`` here instead of the offset would
        # leave three unused ids between the vocabulary and "<mask>".
        self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + self.fairseq_offset
        self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}

    @staticmethod
    def _load_sp_model(vocab_file):
        """Import SentencePiece lazily and return a processor loaded from ``vocab_file``.

        Raises:
            ImportError: re-raised, after logging install instructions, when the
                ``sentencepiece`` package is not available.
        """
        try:
            import sentencepiece as spm
        except ImportError:
            logger.warning(
                "You need to install SentencePiece to use XLMRobertaTokenizer: https://github.com/google/sentencepiece "
                "pip install sentencepiece"
            )
            raise

        sp_model = spm.SentencePieceProcessor()
        # str() so that path-like objects are accepted as well as plain strings.
        sp_model.Load(str(vocab_file))
        return sp_model

    def __getstate__(self):
        """Return the instance state for pickling, without the SentencePiece processor."""
        state = self.__dict__.copy()
        # The processor is dropped here and rebuilt from ``vocab_file`` in __setstate__.
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        """Restore the instance state and reload the SentencePiece model from ``vocab_file``."""
        self.__dict__ = d
        self.sp_model = self._load_sp_model(self.vocab_file)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks
        by concatenating and adding special tokens.
        A RoBERTa sequence has the following format:
            single sequence: <s> X </s>
            pair of sequences: <s> A </s></s> B </s>

        Args:
            token_ids_0: list of ids of the first sequence.
            token_ids_1: optional list of ids of the second sequence.

        Returns:
            A new list of ids with the special token ids inserted.
        """
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        if token_ids_1 is None:
            return cls + token_ids_0 + sep
        return cls + token_ids_0 + sep + sep + token_ids_1 + sep

    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.

        Args:
            token_ids_0: list of ids (must not contain special tokens)
            token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids
                for sequence pairs
            already_has_special_tokens: (default False) Set to True if the token list is already formatted with
                special tokens for the model

        Returns:
            A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.

        Raises:
            ValueError: if ``token_ids_1`` is given together with ``already_has_special_tokens=True``.
        """
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formated with special tokens for the model."
                )
            # Look the two ids up once rather than once per element.
            special_ids = (self.sep_token_id, self.cls_token_id)
            return [1 if token_id in special_ids else 0 for token_id in token_ids_0]

        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]

    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
        RoBERTa does not make use of token type ids, therefore a list of zeros is returned.

        The list is as long as the output of ``build_inputs_with_special_tokens`` for the same arguments.
        If token_ids_1 is None, only the single-sequence length is covered.
        """
        if token_ids_1 is None:
            # <s> X </s>
            return [0] * (len(token_ids_0) + 2)
        # <s> A </s></s> B </s>
        return [0] * (len(token_ids_0) + len(token_ids_1) + 4)

    @property
    def vocab_size(self):
        """Number of ids in use: the shifted SentencePiece pieces plus the trailing ``<mask>``."""
        return len(self.sp_model) + self.fairseq_offset + 1  # + 1 for the <mask> token

    def _tokenize(self, text):
        """Split ``text`` into SentencePiece pieces."""
        return self.sp_model.EncodeAsPieces(text)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        if token in self.fairseq_tokens_to_ids:
            return self.fairseq_tokens_to_ids[token]
        spm_id = self.sp_model.PieceToId(token)
        # spm id 0 is SentencePiece's own '<unk>' (see the alignment table in
        # __init__). Shifting it by the offset would yield 1, i.e. '<pad>', so
        # map it to the fairseq '<unk>' id explicitly.
        if spm_id == 0:
            return self.fairseq_tokens_to_ids["<unk>"]
        return spm_id + self.fairseq_offset

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        if index in self.fairseq_ids_to_tokens:
            return self.fairseq_ids_to_tokens[index]
        return self.sp_model.IdToPiece(index - self.fairseq_offset)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (strings for sub-words) in a single string."""
        out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
        return out_string

    def save_vocabulary(self, save_directory):
        """
        Save the sentencepiece vocabulary (copy original file) to a directory.

        Args:
            save_directory: existing directory to copy the model file into.

        Returns:
            A one-element tuple holding the path of the saved file, or ``None``
            (after logging an error) when ``save_directory`` is not a directory.
        """
        if not os.path.isdir(save_directory):
            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
            return
        out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])

        # Skip the copy when the model file already is the destination file.
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file,)