# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import logging
import os
import json
import six
import copy
import itertools
import re
from io import open

from .file_utils import cached_path, is_tf_available, is_torch_available

if is_tf_available():
    import tensorflow as tf
if is_torch_available():
    import torch

logger = logging.getLogger(__name__)

SPECIAL_TOKENS_MAP_FILE = 'special_tokens_map.json'
ADDED_TOKENS_FILE = 'added_tokens.json'
TOKENIZER_CONFIG_FILE = 'tokenizer_config.json'

class PreTrainedTokenizer(object):
    """ Base class for all tokenizers.
    Handles all the shared methods for tokenization and special tokens, as well as methods for downloading/caching/loading pretrained tokenizers and for adding tokens to the vocabulary.

    This class also contains the added tokens in a unified way on top of all tokenizers so we don't have to handle the specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...).

    Class attributes (overridden by derived classes):

        - ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file required by the model, and as associated values, the filename for saving the associated file (string).
        - ``pretrained_vocab_files_map``: a python ``dict of dict``; the high-level keys are the ``__init__`` keyword name of each vocabulary file required by the model, the low-level keys are the `short-cut-names` (string) of the pretrained models, and the associated values are the `url` (string) of the associated pretrained vocabulary file.
        - ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, the maximum length of the sequence inputs of this model, or None if the model has no maximum input size.
        - ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, a dictionary of specific arguments to pass to the ``__init__`` method of the tokenizer class for this pretrained model when loading the tokenizer with the ``from_pretrained()`` method.

    Parameters:

        - ``bos_token``: (`Optional`) string: a beginning of sentence token. Will be associated to ``self.bos_token`` and ``self.bos_token_id``

        - ``eos_token``: (`Optional`) string: an end of sentence token. Will be associated to ``self.eos_token`` and ``self.eos_token_id``

        - ``unk_token``: (`Optional`) string: an unknown token. Will be associated to ``self.unk_token`` and ``self.unk_token_id``

        - ``sep_token``: (`Optional`) string: a separation token (e.g. to separate context and query in an input sequence). Will be associated to ``self.sep_token`` and ``self.sep_token_id``

        - ``pad_token``: (`Optional`) string: a padding token. Will be associated to ``self.pad_token`` and ``self.pad_token_id``

        - ``cls_token``: (`Optional`) string: a classification token (e.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model). Will be associated to ``self.cls_token`` and ``self.cls_token_id``

        - ``mask_token``: (`Optional`) string: a masking token (e.g. when training a model with masked-language modeling). Will be associated to ``self.mask_token`` and ``self.mask_token_id``

        - ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens. Adding all special tokens here ensures they won't be split by the tokenization process. Will be associated to ``self.additional_special_tokens`` and ``self.additional_special_tokens_ids``
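
        Example (a minimal sketch using a derived class; assumes the ``bert-base-uncased`` vocabulary can be downloaded or is cached)::

            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            # special tokens registered for the model are exposed as attributes, together with their ids
            print(tokenizer.cls_token, tokenizer.cls_token_id)    # '[CLS]' and its index in the vocabulary
            print(tokenizer.sep_token, tokenizer.sep_token_id)    # '[SEP]' and its index in the vocabulary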
    """
    vocab_files_names = {}
    pretrained_vocab_files_map = {}
    pretrained_init_configuration = {}
    max_model_input_sizes = {}

    SPECIAL_TOKENS_ATTRIBUTES = ["bos_token", "eos_token", "unk_token", "sep_token",
                                 "pad_token", "cls_token", "mask_token",
                                 "additional_special_tokens"]

    padding_side = "right"

    @property
    def bos_token(self):
        """ Beginning of sentence token (string). Log an error if used while not having been set. """
        if self._bos_token is None:
            logger.error("Using bos_token, but it is not set yet.")
        return self._bos_token

    @property
    def eos_token(self):
        """ End of sentence token (string). Log an error if used while not having been set. """
        if self._eos_token is None:
            logger.error("Using eos_token, but it is not set yet.")
        return self._eos_token

    @property
    def unk_token(self):
        """ Unknown token (string). Log an error if used while not having been set. """
        if self._unk_token is None:
            logger.error("Using unk_token, but it is not set yet.")
        return self._unk_token

    @property
    def sep_token(self):
        """ Separation token (string). E.g. separate context and query in an input sequence. Log an error if used while not having been set. """
        if self._sep_token is None:
            logger.error("Using sep_token, but it is not set yet.")
        return self._sep_token

    @property
    def pad_token(self):
        """ Padding token (string). Log an error if used while not having been set. """
        if self._pad_token is None:
            logger.error("Using pad_token, but it is not set yet.")
        return self._pad_token

    @property
    def cls_token(self):
        """ Classification token (string). E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """
        if self._cls_token is None:
            logger.error("Using cls_token, but it is not set yet.")
        return self._cls_token

    @property
    def mask_token(self):
        """ Mask token (string). E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """
        if self._mask_token is None:
            logger.error("Using mask_token, but it is not set yet.")
        return self._mask_token

    @property
    def additional_special_tokens(self):
        """ All the additional special tokens you may want to use (list of strings). Log an error if used while not having been set. """
        if self._additional_special_tokens is None:
            logger.error("Using additional_special_tokens, but it is not set yet.")
        return self._additional_special_tokens

    @bos_token.setter
    def bos_token(self, value):
        self._bos_token = value

    @eos_token.setter
    def eos_token(self, value):
        self._eos_token = value

    @unk_token.setter
    def unk_token(self, value):
        self._unk_token = value

    @sep_token.setter
    def sep_token(self, value):
        self._sep_token = value

    @pad_token.setter
    def pad_token(self, value):
        self._pad_token = value

    @cls_token.setter
    def cls_token(self, value):
        self._cls_token = value

    @mask_token.setter
    def mask_token(self, value):
        self._mask_token = value

    @additional_special_tokens.setter
    def additional_special_tokens(self, value):
        self._additional_special_tokens = value

    @property
    def bos_token_id(self):
        """ Id of the beginning of sentence token in the vocabulary. Log an error if used while not having been set. """
        return self.convert_tokens_to_ids(self.bos_token)

    @property
    def eos_token_id(self):
        """ Id of the end of sentence token in the vocabulary. Log an error if used while not having been set. """
        return self.convert_tokens_to_ids(self.eos_token)

    @property
    def unk_token_id(self):
        """ Id of the unknown token in the vocabulary. Log an error if used while not having been set. """
        return self.convert_tokens_to_ids(self.unk_token)

    @property
    def sep_token_id(self):
        """ Id of the separation token in the vocabulary. E.g. separate context and query in an input sequence. Log an error if used while not having been set. """
        return self.convert_tokens_to_ids(self.sep_token)

    @property
    def pad_token_id(self):
        """ Id of the padding token in the vocabulary. Log an error if used while not having been set. """
        return self.convert_tokens_to_ids(self.pad_token)

    @property
    def pad_token_type_id(self):
        """ Id of the padding token type in the vocabulary."""
        return self._pad_token_type_id

    @property
    def cls_token_id(self):
        """ Id of the classification token in the vocabulary. E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """
        return self.convert_tokens_to_ids(self.cls_token)

    @property
    def mask_token_id(self):
        """ Id of the mask token in the vocabulary. E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """
        return self.convert_tokens_to_ids(self.mask_token)

    @property
    def additional_special_tokens_ids(self):
        """ Ids of all the additional special tokens in the vocabulary (list of integers). Log an error if used while not having been set. """
        return self.convert_tokens_to_ids(self.additional_special_tokens)

    def __init__(self, max_len=None, **kwargs):
        self._bos_token = None
        self._eos_token = None
        self._unk_token = None
        self._sep_token = None
        self._pad_token = None
        self._cls_token = None
        self._mask_token = None
        self._pad_token_type_id = 0
        self._additional_special_tokens = []

        self.max_len = max_len if max_len is not None else int(1e12)

        # Padding side is right by default and overridden in subclasses. If specified in the kwargs, it is changed.
        self.padding_side = kwargs.pop('padding_side', self.padding_side)
        
        # Added tokens
        self.added_tokens_encoder = {}
        self.added_tokens_decoder = {}

        # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
        self.init_inputs = ()
        self.init_kwargs = {}

        for key, value in kwargs.items():
            if key in self.SPECIAL_TOKENS_ATTRIBUTES:
                if key == 'additional_special_tokens':
                    assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value)
                else:
                    assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
                setattr(self, key, value)


    @classmethod
    def from_pretrained(cls, *inputs, **kwargs):
        r"""
        Instantiate a :class:`~transformers.PreTrainedTokenizer` (or a derived class) from a predefined tokenizer.

        Args:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``.
                - a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``.
                - (not applicable to all derived classes) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``.

            cache_dir: (`optional`) string:
                Path to a directory in which the downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the vocabulary files and override the cached versions if they exist.

            resume_download: (`optional`) boolean, default False:
                Do not delete an incompletely received file. Attempt to resume the download if such a file exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method.

            kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~transformers.PreTrainedTokenizer` for details.

        Examples::

            # We can't instantiate the base class `PreTrainedTokenizer` directly, so let's show our examples on a derived class: BertTokenizer

            # Download vocabulary from S3 and cache.
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

            # If vocabulary files are in a directory (e.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`)
            tokenizer = BertTokenizer.from_pretrained('./test/saved_model/')

            # If the tokenizer uses a single vocabulary file, you can point directly to this file
            tokenizer = BertTokenizer.from_pretrained('./test/saved_model/my_vocab.txt')

            # You can link tokens to special vocabulary when instantiating
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', unk_token='<unk>')
            # You should be sure '<unk>' is in the vocabulary when doing that.
            # Otherwise use tokenizer.add_special_tokens({'unk_token': '<unk>'}) instead.
            assert tokenizer.unk_token == '<unk>'

        """
        return cls._from_pretrained(*inputs, **kwargs)


    @classmethod
    def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
        cache_dir = kwargs.pop('cache_dir', None)
        force_download = kwargs.pop('force_download', False)
        resume_download = kwargs.pop('resume_download', False)
        proxies = kwargs.pop('proxies', None)

        s3_models = list(cls.max_model_input_sizes.keys())
        vocab_files = {}
        init_configuration = {}
        if pretrained_model_name_or_path in s3_models:
            # Get the vocabulary from AWS S3 bucket
            for file_id, map_list in cls.pretrained_vocab_files_map.items():
                vocab_files[file_id] = map_list[pretrained_model_name_or_path]
            if cls.pretrained_init_configuration and pretrained_model_name_or_path in cls.pretrained_init_configuration:
                init_configuration = cls.pretrained_init_configuration[pretrained_model_name_or_path]
        else:
            # Get the vocabulary from local files
            logger.info(
                "Model name '{}' not found in model shortcut name list ({}). "
                "Assuming '{}' is a path or url to a directory containing tokenizer files.".format(
                    pretrained_model_name_or_path, ', '.join(s3_models),
                    pretrained_model_name_or_path))

            # Look for the tokenizer main vocabulary files
            for file_id, file_name in cls.vocab_files_names.items():
                if os.path.isdir(pretrained_model_name_or_path):
                    # If a directory is provided we look for the standard filenames
                    full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
                else:
                    # If a path to a file is provided we use it (will only work for non-BPE tokenizer using a single vocabulary file)
                    full_file_name = pretrained_model_name_or_path
                if not os.path.exists(full_file_name):
                    logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
                    full_file_name = None
                vocab_files[file_id] = full_file_name

            # Look for the additional tokens files
            additional_files_names = {'added_tokens_file': ADDED_TOKENS_FILE,
                                      'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE,
                                      'tokenizer_config_file': TOKENIZER_CONFIG_FILE,
                                      }

            # If a path to a file was provided, get the parent directory
            saved_directory = pretrained_model_name_or_path
            if os.path.exists(saved_directory) and not os.path.isdir(saved_directory):
                saved_directory = os.path.dirname(saved_directory)

            for file_id, file_name in additional_files_names.items():
                full_file_name = os.path.join(saved_directory, file_name)
                if not os.path.exists(full_file_name):
                    logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
                    full_file_name = None
                vocab_files[file_id] = full_file_name

            if all(full_file_name is None for full_file_name in vocab_files.values()):
                raise EnvironmentError(
                    "Model name '{}' was not found in tokenizers model name list ({}). "
                    "We assumed '{}' was a path or url to a directory containing vocabulary files "
                    "named {} but couldn't find such vocabulary files at this path or url.".format(
                        pretrained_model_name_or_path, ', '.join(s3_models),
                        pretrained_model_name_or_path,
                        list(cls.vocab_files_names.values())))

        # Get files from url, cache, or disk depending on the case
        try:
            resolved_vocab_files = {}
            for file_id, file_path in vocab_files.items():
                if file_path is None:
                    resolved_vocab_files[file_id] = None
                else:
                    resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download)
        except EnvironmentError:
            if pretrained_model_name_or_path in s3_models:
                msg = "Couldn't reach server at '{}' to download vocabulary files."
            else:
                msg = "Model name '{}' was not found in tokenizers model name list ({}). " \
                    "We assumed '{}' was a path or url to a directory containing vocabulary files " \
                    "named {}, but couldn't find such vocabulary files at this path or url.".format(
                        pretrained_model_name_or_path, ', '.join(s3_models),
                        pretrained_model_name_or_path,
                        list(cls.vocab_files_names.values()))

            raise EnvironmentError(msg)

        for file_id, file_path in vocab_files.items():
            if file_path == resolved_vocab_files[file_id]:
                logger.info("loading file {}".format(file_path))
            else:
                logger.info("loading file {} from cache at {}".format(
                    file_path, resolved_vocab_files[file_id]))

        # Prepare tokenizer initialization kwargs
        # Did we save some inputs and kwargs to reload?
        tokenizer_config_file = resolved_vocab_files.pop('tokenizer_config_file', None)
        if tokenizer_config_file is not None:
            with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
                init_kwargs = json.load(tokenizer_config_handle)
            saved_init_inputs = init_kwargs.pop('init_inputs', ())
            if not init_inputs:
                init_inputs = saved_init_inputs
        else:
            init_kwargs = init_configuration

        # Update with newly provided kwargs
        init_kwargs.update(kwargs)

        # Set max length if needed
        if pretrained_model_name_or_path in cls.max_model_input_sizes:
            # if we're using a pretrained model, ensure the tokenizer
            # won't index sequences longer than the number of positional embeddings
            max_len = cls.max_model_input_sizes[pretrained_model_name_or_path]
            if max_len is not None and isinstance(max_len, (int, float)):
                init_kwargs['max_len'] = min(init_kwargs.get('max_len', int(1e12)), max_len)

        # Merge resolved_vocab_files arguments in init_kwargs.
        added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None)
        special_tokens_map_file = resolved_vocab_files.pop('special_tokens_map_file', None)
        for args_name, file_path in resolved_vocab_files.items():
            if args_name not in init_kwargs:
                init_kwargs[args_name] = file_path
        if special_tokens_map_file is not None:
            with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
                special_tokens_map = json.load(special_tokens_map_handle)
            for key, value in special_tokens_map.items():
                if key not in init_kwargs:
                    init_kwargs[key] = value

        # Instantiate tokenizer.
        tokenizer = cls(*init_inputs, **init_kwargs)

        # Save inputs and kwargs for saving and re-loading with ``save_pretrained``
        tokenizer.init_inputs = init_inputs
        tokenizer.init_kwargs = init_kwargs

        # Add supplementary tokens.
        if added_tokens_file is not None:
            with open(added_tokens_file, encoding="utf-8") as added_tokens_handle:
                added_tok_encoder = json.load(added_tokens_handle)
            added_tok_decoder = {v:k for k, v in added_tok_encoder.items()}
            tokenizer.added_tokens_encoder.update(added_tok_encoder)
            tokenizer.added_tokens_decoder.update(added_tok_decoder)

        return tokenizer


    def save_pretrained(self, save_directory):
        """ Save the tokenizer vocabulary files together with:
                - added tokens,
                - special-tokens-to-class-attributes-mapping,
                - tokenizer instantiation positional and keyword inputs (e.g. do_lower_case for Bert).

            This won't save modifications (other than added tokens and special token mapping) you may have
            applied to the tokenizer after the instantiation (e.g. modifying tokenizer.do_lower_case after creation).

            This method makes sure the full tokenizer can then be re-loaded using the :func:`~transformers.PreTrainedTokenizer.from_pretrained` class method.
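
            Example (a minimal sketch, assuming a BertTokenizer and an existing ``./my_model_directory/``)::

                tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
                tokenizer.save_pretrained('./my_model_directory/')
                # the full tokenizer state (vocabulary, added tokens, special tokens, init kwargs) can then be reloaded
                tokenizer = BertTokenizer.from_pretrained('./my_model_directory/')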
        """
        if not os.path.isdir(save_directory):
            logger.error("Saving directory ({}) should be a directory".format(save_directory))
            return

        special_tokens_map_file = os.path.join(save_directory, SPECIAL_TOKENS_MAP_FILE)
        added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE)
        tokenizer_config_file = os.path.join(save_directory, TOKENIZER_CONFIG_FILE)

        tokenizer_config = copy.deepcopy(self.init_kwargs)
        tokenizer_config['init_inputs'] = copy.deepcopy(self.init_inputs)
        for file_id in self.vocab_files_names.keys():
            tokenizer_config.pop(file_id, None)

        with open(tokenizer_config_file, 'w', encoding='utf-8') as f:
            f.write(json.dumps(tokenizer_config, ensure_ascii=False))

        with open(special_tokens_map_file, 'w', encoding='utf-8') as f:
            f.write(json.dumps(self.special_tokens_map, ensure_ascii=False))

        with open(added_tokens_file, 'w', encoding='utf-8') as f:
            if self.added_tokens_encoder:
                out_str = json.dumps(self.added_tokens_encoder, ensure_ascii=False)
            else:
                out_str = u"{}"
            f.write(out_str)

        vocab_files = self.save_vocabulary(save_directory)

        return vocab_files + (special_tokens_map_file, added_tokens_file)


    def save_vocabulary(self, save_directory):
        """ Save the tokenizer vocabulary to a directory. This method does *NOT* save added tokens
            and special token mappings.

            Please use :func:`~transformers.PreTrainedTokenizer.save_pretrained` to save the full Tokenizer state if you want to reload it using the :func:`~transformers.PreTrainedTokenizer.from_pretrained` class method.
        """
        raise NotImplementedError


    def vocab_size(self):
        """ Size of the base vocabulary (without the added tokens) """
        raise NotImplementedError


    def __len__(self):
        """ Size of the full vocabulary with the added tokens """
        return self.vocab_size + len(self.added_tokens_encoder)


    def add_tokens(self, new_tokens):
        """
        Add a list of new tokens to the tokenizer class. If the new tokens are not in the
        vocabulary, they are added to it with indices starting from the length of the current vocabulary.

        Args:
            new_tokens: list of string. Each string is a token to add. Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them).

        Returns:
            Number of tokens added to the vocabulary.

        Examples::

            # Let's see how to increase the vocabulary of Bert model and tokenizer
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            model = BertModel.from_pretrained('bert-base-uncased')

            num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
            print('We have added', num_added_toks, 'tokens')
            model.resize_token_embeddings(len(tokenizer))  # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
        """
        if not new_tokens:
            return 0

        to_add_tokens = []
        for token in new_tokens:
            assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode))
            if self.init_kwargs.get('do_lower_case', False) and token not in self.all_special_tokens:
                token = token.lower()
            if token != self.unk_token and \
                    self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token) and \
                    token not in to_add_tokens:
                to_add_tokens.append(token)
                logger.info("Adding %s to the vocabulary", token)

        added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(to_add_tokens))
        added_tok_decoder = {v:k for k, v in added_tok_encoder.items()}
        self.added_tokens_encoder.update(added_tok_encoder)
        self.added_tokens_decoder.update(added_tok_decoder)

        return len(to_add_tokens)

    def num_added_tokens(self, pair=False):
        """
        Returns the number of added tokens when encoding a sequence with special tokens.

        Note:
            This encodes inputs and checks the number of added tokens, and is therefore not efficient. Do not put this
            inside your training loop.

        Args:
            pair: Returns the number of added tokens in the case of a sequence pair if set to True, returns the
                number of added tokens in the case of a single sequence if set to False.

        Returns:
            Number of tokens added to sequences
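
        Example (a hedged sketch, assuming a BERT-like tokenizer that adds [CLS] and [SEP])::

            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            tokenizer.num_added_tokens()           # 2 for a single sequence: [CLS] ... [SEP]
            tokenizer.num_added_tokens(pair=True)  # 3 for a pair: [CLS] A [SEP] B [SEP]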
        """
        token_ids_0 = []
        token_ids_1 = []
        return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None))

    def add_special_tokens(self, special_tokens_dict):
        """
        Add a dictionary of special tokens (eos, pad, cls...) to the encoder and link them
        to class attributes. If special tokens are NOT in the vocabulary, they are added
        to it (indexed starting from the last index of the current vocabulary).

        Using `add_special_tokens` will ensure your special tokens can be used in several ways:

        - special tokens are carefully handled by the tokenizer (they are never split)
        - you can easily refer to special tokens using tokenizer class attributes like `tokenizer.cls_token`. This makes it easy to develop model-agnostic training and fine-tuning scripts.

        When possible, special tokens are already registered for provided pretrained models (ex: BertTokenizer cls_token is already registered to be '[CLS]' and XLM's one is also registered to be '</s>')

        Args:
            special_tokens_dict: dict of string. Keys should be in the list of predefined special attributes:
                [``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``,
                ``additional_special_tokens``].

                Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them).

        Returns:
            Number of tokens added to the vocabulary.

        Examples::

            # Let's see how to add a new classification token to GPT-2
            tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            model = GPT2Model.from_pretrained('gpt2')

            special_tokens_dict = {'cls_token': '<CLS>'}

            num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
            print('We have added', num_added_toks, 'tokens')
            model.resize_token_embeddings(len(tokenizer))  # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.

            assert tokenizer.cls_token == '<CLS>'
        """
        if not special_tokens_dict:
            return 0

        added_tokens = 0
        for key, value in special_tokens_dict.items():
            assert key in self.SPECIAL_TOKENS_ATTRIBUTES
            if key == 'additional_special_tokens':
                assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value)
                added_tokens += self.add_tokens(value)
            else:
                assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
                added_tokens += self.add_tokens([value])
            logger.info("Assigning %s to the %s key of the tokenizer", value, key)
            setattr(self, key, value)

        return added_tokens

    def tokenize(self, text, **kwargs):
        """ Converts a string in a sequence of tokens (string), using the tokenizer.
            Splits into words for word-based vocabularies or into sub-words for sub-word-based
            vocabularies (BPE/SentencePieces/WordPieces).

            Takes care of added tokens.

            text: The sequence to be encoded.
            return_tokens_mapped_to_origin: (optional) Set to True to return the index of each token in the initial whitespace tokenization. (default False).
            **kwargs: passed to the model-specific `self._tokenize()` method
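
            Example (a minimal sketch, assuming a GPT-2 BPE tokenizer)::

                tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
                tokens = tokenizer.tokenize("Hello world")
                # added tokens and registered special tokens are kept whole; everything else is sub-word split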
        """
        def lowercase_text(t):
            # convert non-special tokens to lowercase
            escaped_special_toks = [re.escape(s_tok) for s_tok in self.all_special_tokens]
            pattern = r'(^' + r'|'.join(escaped_special_toks) + r')|' + \
                      r'(.+?)'
            return re.sub(
                pattern,
                lambda m: m.groups()[0] or m.groups()[1].lower(),
                t)

        if self.init_kwargs.get('do_lower_case', False):
            text = lowercase_text(text)

        def split_on_token(tok, text):
            result = []
            split_text = text.split(tok)
            for i, sub_text in enumerate(split_text):
                sub_text = sub_text.strip()
                if i == 0 and not sub_text:
                    result += [tok]
                elif i == len(split_text) - 1:
                    if sub_text:
                        result += [sub_text]
                    else:
                        pass
                else:
                    if sub_text:
                        result += [sub_text]
                    result += [tok]
            return result

        def split_on_tokens(tok_list, text):
            if not text:
                return []
            if not tok_list:
                return self._tokenize(text, **kwargs)

            tokenized_text = []
            text_list = [text]
            for tok in tok_list:
                tokenized_text = []
                for sub_text in text_list:
                    if sub_text not in self.added_tokens_encoder \
                            and sub_text not in self.all_special_tokens:
                        tokenized_text += split_on_token(tok, sub_text)
                    else:
                        tokenized_text += [sub_text]
                text_list = tokenized_text

            return list(itertools.chain.from_iterable((self._tokenize(token, **kwargs) if token not \
                    in self.added_tokens_encoder and token not in self.all_special_tokens \
                    else [token] for token in tokenized_text)))

        added_tokens = list(self.added_tokens_encoder.keys()) + self.all_special_tokens
        tokenized_text = split_on_tokens(added_tokens, text)
        return tokenized_text

    def _tokenize(self, text, **kwargs):
        """ Converts a string in a sequence of tokens (string), using the tokenizer.
            Splits into words for word-based vocabularies or into sub-words for sub-word-based
            vocabularies (BPE/SentencePieces/WordPieces).

            Does NOT take care of added tokens.
        """
        raise NotImplementedError

    def convert_tokens_to_ids(self, tokens):
        """ Converts a single token, or a sequence of tokens, (str/unicode) in a single integer id
            (resp. a sequence of ids), using the vocabulary.
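
            Example (a hedged sketch; the returned ids depend on the loaded vocabulary)::

                ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("Hello world"))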
        """
        if tokens is None:
            return None

        if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)):
            return self._convert_token_to_id_with_added_voc(tokens)

        ids = []
        for token in tokens:
            ids.append(self._convert_token_to_id_with_added_voc(token))
        return ids

    def _convert_token_to_id_with_added_voc(self, token):
        if token is None:
            return None

        if token in self.added_tokens_encoder:
            return self.added_tokens_encoder[token]
        return self._convert_token_to_id(token)

    def _convert_token_to_id(self, token):
        raise NotImplementedError

    def encode(self,
               text,
               text_pair=None,
               add_special_tokens=True,
               max_length=None,
               stride=0,
               truncation_strategy='longest_first',
               pad_to_max_length=False,
               return_tensors=None,
               **kwargs):
        """
        Converts a string into a sequence of ids (integers), using the tokenizer and vocabulary.

        Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``.

        Args:
            text: The first sequence to be encoded. This can be a string, a list of strings (tokenized string using
                the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
                method)
            text_pair: Optional second sequence to be encoded. This can be a string, a list of strings (tokenized
                string using the `tokenize` method) or a list of integers (tokenized string ids using the
                `convert_tokens_to_ids` method)
            add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
                to their model.
            max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
                If there are overflowing tokens, those will be added to the returned dictionary
            stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
                from the main sequence returned. The value of this argument defines the number of additional tokens.
            truncation_strategy: string selected in the following options:
                - 'longest_first' (default): iteratively reduce the inputs until the total length is under max_length,
                    removing one token at a time from the longest sequence (when there is a pair of input sequences)
                - 'only_first': Only truncate the first sequence
                - 'only_second': Only truncate the second sequence
                - 'do_not_truncate': Does not truncate (raises an error if the input sequence is longer than max_length)
            pad_to_max_length: if set to True, the returned sequences will be padded according to the model's padding side and
                padding index, up to their max length. If no max length is specified, the padding is done up to the model's max length.
                The tokenizer padding sides are handled by the following strings:
                - 'left': pads on the left of the sequences
                - 'right': pads on the right of the sequences   
                Defaults to False: no padding.
            return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
                or PyTorch torch.Tensor instead of a list of python integers.
            **kwargs: passed to the `self.tokenize()` method
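
        Example (a hedged sketch with a BERT tokenizer; the exact ids depend on the model's vocabulary)::

            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            ids = tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)
            # roughly: [cls_token_id] + convert_tokens_to_ids(tokenize(text)) + [sep_token_id]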
        """
        encoded_inputs = self.encode_plus(text,
                                          text_pair=text_pair,
                                          max_length=max_length,
                                          add_special_tokens=add_special_tokens,
                                          stride=stride,
                                          truncation_strategy=truncation_strategy,
                                          pad_to_max_length=pad_to_max_length,
                                          return_tensors=return_tensors,
                                          **kwargs)

        return encoded_inputs["input_ids"]

    def encode_plus(self,
                    text,
                    text_pair=None,
                    add_special_tokens=True,
                    max_length=None,
                    stride=0,
                    truncation_strategy='longest_first',
                    pad_to_max_length=False,
                    return_tensors=None,
                    return_token_type_ids=True,
                    return_attention_mask=True,
                    return_overflowing_tokens=False,
                    return_special_tokens_mask=False,
                    **kwargs):
        """
        Returns a dictionary containing the encoded sequence or sequence pair and additional information:
        the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.

        Args:
            text: The first sequence to be encoded. This can be a string, a list of strings (tokenized string using
                the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
                method)
            text_pair: Optional second sequence to be encoded. This can be a string, a list of strings (tokenized
                string using the `tokenize` method) or a list of integers (tokenized string ids using the
                `convert_tokens_to_ids` method)
            add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
                to their model.
            max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
                If there are overflowing tokens, those will be added to the returned dictionary
            stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
                from the main sequence returned. The value of this argument defines the number of additional tokens.
            truncation_strategy: string selected in the following options:
                - 'longest_first' (default): iteratively reduce the inputs until the total length is under max_length,
                    removing one token at a time from the longest sequence (when there is a pair of input sequences)
                - 'only_first': Only truncate the first sequence
                - 'only_second': Only truncate the second sequence
                - 'do_not_truncate': Does not truncate (raises an error if the input sequence is longer than max_length)
            pad_to_max_length: if set to True, the returned sequences will be padded according to the model's padding side and
                padding index, up to their max length. If no max length is specified, the padding is done up to the model's max length.
                The tokenizer padding sides are handled by the following strings:
                - 'left': pads on the left of the sequences
                - 'right': pads on the right of the sequences   
                Defaults to False: no padding.
            return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
                or PyTorch torch.Tensor instead of a list of python integers.
            return_token_type_ids: (optional) Set to False to avoid returning token_type_ids (default True).
            return_attention_mask: (optional) Set to False to avoid returning attention mask (default True)
            return_overflowing_tokens: (optional) Set to True to return overflowing token information (default False).
            return_special_tokens_mask: (optional) Set to True to return special tokens mask information (default False).
            **kwargs: passed to the `self.tokenize()` method

        Return:
            A Dictionary of shape::

                {
                    input_ids: list[int],
                    token_type_ids: list[int] if return_token_type_ids is True (default)
                    attention_mask: list[int] if return_attention_mask is True (default)
                    overflowing_tokens: list[int] if a ``max_length`` is specified and return_overflowing_tokens is True
                    num_truncated_tokens: int if a ``max_length`` is specified and return_overflowing_tokens is True
                    special_tokens_mask: list[int] if ``add_special_tokens`` if set to ``True`` and return_special_tokens_mask is True
                }

            With the fields:
                ``input_ids``: list of token ids to be fed to a model
                ``token_type_ids``: list of token type ids to be fed to a model
                ``attention_mask``: list of indices specifying which tokens should be attended to by the model
                ``overflowing_tokens``: list of overflowing tokens if a max length is specified.
                ``num_truncated_tokens``: number of overflowing tokens if a ``max_length`` is specified
                ``special_tokens_mask``: if adding special tokens, this is a list of [0, 1], with 1 specifying special added
                tokens and 0 specifying sequence tokens.
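
        Example (a hedged sketch on a derived-class instance; the keys shown assume the default return flags)::

            enc = tokenizer.encode_plus("Hello world", max_length=8, pad_to_max_length=True)
            input_ids = enc["input_ids"]            # padded/truncated to length 8
            attention_mask = enc["attention_mask"]  # 1 for real tokens, 0 for padding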
        """

        def get_input_ids(text):
            if isinstance(text, six.string_types):
                return self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], six.string_types):
                return self.convert_tokens_to_ids(text)
            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
                return text
            else:
                raise ValueError("Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers.")

        first_ids = get_input_ids(text)
        second_ids = get_input_ids(text_pair) if text_pair is not None else None

        return self.prepare_for_model(first_ids,
                                      pair_ids=second_ids,
                                      max_length=max_length,
                                      pad_to_max_length=pad_to_max_length,
                                      add_special_tokens=add_special_tokens,
                                      stride=stride,
                                      truncation_strategy=truncation_strategy,
                                      return_tensors=return_tensors,
                                      return_attention_mask=return_attention_mask,
                                      return_token_type_ids=return_token_type_ids,
                                      return_overflowing_tokens=return_overflowing_tokens,
                                      return_special_tokens_mask=return_special_tokens_mask)

    def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=True, stride=0,
                          truncation_strategy='longest_first',
                          pad_to_max_length=False,
                          return_tensors=None,
                          return_token_type_ids=True,
                          return_attention_mask=True,
                          return_overflowing_tokens=False,
                          return_special_tokens_mask=False):
        """
        Prepares a sequence of input ids, or a pair of sequences of input ids, so that it can be used by the model.
        It adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
        manages a window stride for overflowing tokens

        Args:
            ids: list of tokenized input ids. Can be obtained from a string by chaining the
                `tokenize` and `convert_tokens_to_ids` methods.
            pair_ids: Optional second list of input ids. Can be obtained from a string by chaining the
                `tokenize` and `convert_tokens_to_ids` methods.
            max_length: maximum length of the returned list. Will truncate by taking into account the special tokens.
            add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
                to their model.
            stride: window stride for overflowing tokens. Can be useful for edge effect removal when using sequential
                list of inputs.
            truncation_strategy: string selected in the following options:
                - 'longest_first' (default): iteratively reduce the inputs until the total length is under max_length,
                    removing one token at a time from the longest sequence (when there is a pair of input sequences)
                - 'only_first': Only truncate the first sequence
                - 'only_second': Only truncate the second sequence
                - 'do_not_truncate': Does not truncate (raises an error if the input sequence is longer than max_length)
            pad_to_max_length: if set to True, the returned sequences will be padded according to the model's padding side and
                padding index, up to their max length. If no max length is specified, the padding is done up to the model's max length.
                The tokenizer padding sides are handled by the following strings:
                - 'left': pads on the left of the sequences
                - 'right': pads on the right of the sequences   
                Defaults to False: no padding.
            return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
                or PyTorch torch.Tensor instead of a list of python integers.
            return_token_type_ids: (optional) Set to False to avoid returning token_type_ids (default True).
            return_attention_mask: (optional) Set to False to avoid returning attention mask (default True)
            return_overflowing_tokens: (optional) Set to True to return overflowing token information (default False).
            return_special_tokens_mask: (optional) Set to True to return special tokens mask information (default False).

        Return:
            A Dictionary of shape::

                {
                    input_ids: list[int],
                    token_type_ids: list[int] if return_token_type_ids is True (default)
                    overflowing_tokens: list[int] if a ``max_length`` is specified and return_overflowing_tokens is True
                    num_truncated_tokens: int if a ``max_length`` is specified and return_overflowing_tokens is True
                    special_tokens_mask: list[int] if ``add_special_tokens`` if set to ``True`` and return_special_tokens_mask is True
LysandreJik's avatar
LysandreJik committed
933
934
935
                }

            With the fields:
thomwolf's avatar
thomwolf committed
936
937
                ``input_ids``: list of token ids to be fed to a model
                ``token_type_ids``: list of token type ids to be fed to a model
LysandreJik's avatar
LysandreJik committed
938
939

                ``overflowing_tokens``: list of overflowing tokens if a max length is specified.
thomwolf's avatar
thomwolf committed
940
                ``num_truncated_tokens``: number of overflowing tokens a ``max_length`` is specified
941
                ``special_tokens_mask``: if adding special tokens, this is a list of [0, 1], with 0 specifying special added
LysandreJik's avatar
LysandreJik committed
942
                tokens and 1 specifying sequence tokens.
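
        Examples (a minimal sketch, assuming a BERT-like pretrained tokenizer has already been loaded as
        ``tokenizer``; the input strings are arbitrary)::

            ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("Hello world"))
            pair_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("How are you?"))
            encoded = tokenizer.prepare_for_model(ids, pair_ids=pair_ids, max_length=16,
                                                  add_special_tokens=True, pad_to_max_length=True)
            # encoded["input_ids"], encoded["token_type_ids"] and encoded["attention_mask"]
            # are plain python lists padded up to max_length (16 here).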
        """
        pair = bool(pair_ids is not None)
        len_ids = len(ids)
        len_pair_ids = len(pair_ids) if pair else 0

        encoded_inputs = {}

        # Handle max sequence length
        total_len = len_ids + len_pair_ids + (self.num_added_tokens(pair=pair) if add_special_tokens else 0)
        if max_length and total_len > max_length:
            ids, pair_ids, overflowing_tokens = self.truncate_sequences(ids, pair_ids=pair_ids,
                                                                        num_tokens_to_remove=total_len-max_length,
                                                                        truncation_strategy=truncation_strategy,
                                                                        stride=stride)
            if return_overflowing_tokens:
                encoded_inputs["overflowing_tokens"] = overflowing_tokens
                encoded_inputs["num_truncated_tokens"] = total_len - max_length

        # Handle special_tokens
        if add_special_tokens:
            sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
            token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
            special_tokens_mask = self.get_special_tokens_mask(ids, pair_ids)
        else:
            sequence = ids + pair_ids if pair else ids
            token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else [])
            special_tokens_mask = [0] * (len(ids) + (len(pair_ids) if pair else 0))
        if return_special_tokens_mask:
            encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)

        # Prepare inputs as tensors if asked
        if return_tensors == 'tf' and is_tf_available():
            sequence = tf.constant([sequence])
            token_type_ids = tf.constant([token_type_ids])
        elif return_tensors == 'pt' and is_torch_available():
            sequence = torch.tensor([sequence])
            token_type_ids = torch.tensor([token_type_ids])
        elif return_tensors is not None:
            logger.warning("Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(return_tensors))

        encoded_inputs["input_ids"] = sequence
        if return_token_type_ids:
            encoded_inputs["token_type_ids"] = token_type_ids

        if max_length and len(encoded_inputs["input_ids"]) > max_length:
            encoded_inputs["input_ids"] = encoded_inputs["input_ids"][:max_length]
            if return_token_type_ids:
                encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"][:max_length]
            if return_special_tokens_mask:
                encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"][:max_length]

        if max_length is None and len(encoded_inputs["input_ids"]) > self.max_len:
            logger.warning("Token indices sequence length is longer than the specified maximum sequence length "
                           "for this model ({} > {}). Running this sequence through the model will result in "
                           "indexing errors".format(len(encoded_inputs["input_ids"]), self.max_len))
        # Pad up to max_length, or up to the model max length when no max_length is given and the
        # model defines a usable maximum (self.max_len <= 10000).
        needs_to_be_padded = pad_to_max_length and (
            (max_length and len(encoded_inputs["input_ids"]) < max_length)
            or
            (max_length is None and len(encoded_inputs["input_ids"]) < self.max_len and self.max_len <= 10000)
        )

        if pad_to_max_length and max_length is None and self.max_len > 10000:
            logger.warning("Sequence can't be padded as no maximum length is specified and "
                           "the model maximum sequence length is too large (> 10000).")

        if needs_to_be_padded:
            difference = (max_length if max_length is not None else self.max_len) - len(encoded_inputs["input_ids"])
            if self.padding_side == 'right':
                if return_attention_mask:
                    encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + [0] * difference
                if return_token_type_ids:
                    encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
                if return_special_tokens_mask:
                    encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
                encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference

            elif self.padding_side == 'left':
                if return_attention_mask:
                    encoded_inputs["attention_mask"] =  [0] * difference + [1] * len(encoded_inputs["input_ids"])
                if return_token_type_ids:
                    encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs["token_type_ids"]
                if return_special_tokens_mask:
                    encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
                encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"]

            else:
                raise ValueError("Invalid padding strategy:" + str(self.padding_side))
        elif return_attention_mask:
            encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"])
        return encoded_inputs

    def truncate_sequences(self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy='longest_first', stride=0):
        """Truncates a sequence pair in place to the maximum length.
            truncation_strategy: string selected in the following options:
                - 'longest_first' (default) Iteratively reduce the inputs sequence until the input is under max_length
                    starting from the longest one at each token (when there is a pair of input sequences).
                    Overflowing tokens only contains overflow from the first sequence.
                - 'only_first': Only truncate the first sequence. raise an error if the first sequence is shorter or equal to than num_tokens_to_remove.
                - 'only_second': Only truncate the second sequence
                - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
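
            Example (a minimal sketch of the default 'longest_first' strategy; the ids below are arbitrary)::

                ids, pair_ids, overflow = tokenizer.truncate_sequences([1, 2, 3, 4], pair_ids=[5, 6],
                                                                       num_tokens_to_remove=2)
                # ids == [1, 2], pair_ids == [5, 6], overflow == [3, 4]
                # (only the longest of the two sequences is truncated, and the removed
                #  tokens from the first sequence are returned as overflow)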
        """
        if num_tokens_to_remove <= 0:
            return ids, pair_ids, []

        if truncation_strategy == 'longest_first':
            overflowing_tokens = []
            for _ in range(num_tokens_to_remove):
                if pair_ids is None or len(ids) > len(pair_ids):
                    overflowing_tokens = [ids[-1]] + overflowing_tokens
                    ids = ids[:-1]
                else:
                    pair_ids = pair_ids[:-1]
            window_len = min(len(ids), stride)
            if window_len > 0:
                overflowing_tokens = ids[-window_len:] + overflowing_tokens
        elif truncation_strategy == 'only_first':
            assert len(ids) > num_tokens_to_remove
            window_len = min(len(ids), stride + num_tokens_to_remove)
            overflowing_tokens = ids[-window_len:]
            ids = ids[:-num_tokens_to_remove]
        elif truncation_strategy == 'only_second':
            assert pair_ids is not None and len(pair_ids) > num_tokens_to_remove
            window_len = min(len(pair_ids), stride + num_tokens_to_remove)
            overflowing_tokens = pair_ids[-window_len:]
            pair_ids = pair_ids[:-num_tokens_to_remove]
        elif truncation_strategy == 'do_not_truncate':
            raise ValueError("Input sequence are too long for max_length. Please select a truncation strategy.")
        else:
            raise ValueError("Truncation_strategy should be selected in ['longest_first', 'only_first', 'only_second', 'do_not_truncate']")
        return (ids, pair_ids, overflowing_tokens)

    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
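        """ Creates the token type ids corresponding to the sequences passed: 0 for every token of
            the first sequence and 1 for every token of the (optional) second one.
            Model-specific tokenizers typically override this method to account for their special tokens.
        """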
        if token_ids_1 is None:
            return len(token_ids_0) * [0]
        return [0] * len(token_ids_0) + [1] * len(token_ids_1)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks
        by concatenating and adding special tokens.
        This base implementation simply concatenates the sequences without adding any special tokens;
        model-specific tokenizers override it to add the tokens their model expects
        (e.g. classification and separator tokens).
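
        Example of the base behaviour (model-specific tokenizers will return different sequences)::

            tokenizer.build_inputs_with_special_tokens([1, 2, 3])        # -> [1, 2, 3]
            tokenizer.build_inputs_with_special_tokens([1, 2], [3, 4])   # -> [1, 2, 3, 4]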
        """
        if token_ids_1 is None:
            return token_ids_0
        return token_ids_0 + token_ids_1

    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.

        Args:
            token_ids_0: list of ids (must not contain special tokens)
            token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids
                for sequence pairs
            already_has_special_tokens: (default False) Set to True if the token list is already formatted with
                special tokens for the model

        Returns:
            A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
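
            Example of the base behaviour, which marks no token as special (derived tokenizers typically
            override this method to flag the positions of their own special tokens)::

                tokenizer.get_special_tokens_mask([5, 6, 7])          # -> [0, 0, 0]
                tokenizer.get_special_tokens_mask([5, 6], [7, 8])     # -> [0, 0, 0, 0]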
        """
        return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))

    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        """ Converts a single index or a sequence of indices (integers) in a token "
            (resp.) a sequence of tokens (str/unicode), using the vocabulary and added tokens.

            Args:
                skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
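
            Example (assuming a pretrained tokenizer already loaded as ``tokenizer``, and an input whose
            tokens are all part of its vocabulary)::

                tokens = tokenizer.tokenize("Hello world")
                ids = tokenizer.convert_tokens_to_ids(tokens)
                tokenizer.convert_ids_to_tokens(ids)  # -> the original `tokens` list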
        """
        if isinstance(ids, int):
            if ids in self.added_tokens_decoder:
                return self.added_tokens_decoder[ids]
            else:
                return self._convert_id_to_token(ids)
        tokens = []
        for index in ids:
            if skip_special_tokens and index in self.all_special_ids:
                continue
            if index in self.added_tokens_decoder:
                tokens.append(self.added_tokens_decoder[index])
            else:
                tokens.append(self._convert_id_to_token(index))
        return tokens

    def _convert_id_to_token(self, index):
        raise NotImplementedError

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (strings) into a single string.
            The simplest way to do it is ' '.join(tokens),
            but we often want to remove sub-word tokenization artifacts at the same time.
        """
        return ' '.join(tokens)

    def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
        """
        Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
        with options to remove special tokens and clean up tokenization spaces.
        Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.

        Args:
            token_ids: list of tokenized input ids. Can be obtained using the `encode` or `encode_plus` methods.
            skip_special_tokens: if set to True, will remove special tokens from the decoded string.
            clean_up_tokenization_spaces: if set to True, will clean up the tokenization spaces.
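
        Examples (a minimal sketch, assuming a pretrained tokenizer has already been loaded as ``tokenizer``)::

            ids = tokenizer.encode("Hello world!")
            text = tokenizer.decode(ids, skip_special_tokens=True)
            # `text` is a plain python string close to the original input, with special tokens removed
            # and tokenization artifacts (e.g. spaces before punctuation) cleaned up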
        """
        filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)

        # To avoid mixing byte-level and unicode for byte-level BPE
        # we need to build the string separately for added tokens and byte-level tokens
        # cf. https://github.com/huggingface/transformers/issues/1133
        sub_texts = []
        current_sub_text = []
        for token in filtered_tokens:
            if skip_special_tokens and token in self.all_special_tokens:
                continue
            if token in self.added_tokens_encoder:
                if current_sub_text:
                    sub_texts.append(self.convert_tokens_to_string(current_sub_text))
                    current_sub_text = []
                sub_texts.append(" " + token)
            else:
                current_sub_text.append(token)
        if current_sub_text:
            sub_texts.append(self.convert_tokens_to_string(current_sub_text))
        text = ''.join(sub_texts)

        if clean_up_tokenization_spaces:
            clean_text = self.clean_up_tokenization(text)
            return clean_text
        else:
            return text

    @property
    def special_tokens_map(self):
        """ A dictionary mapping special token class attribute (cls_token, unk_token...) to their
            values ('<unk>', '<cls>'...)
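            For a BERT-like tokenizer this might look like
            ``{'unk_token': '[UNK]', 'sep_token': '[SEP]', 'cls_token': '[CLS]', ...}`` (illustrative only).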
        """
        set_attr = {}
        for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
            attr_value = getattr(self, "_" + attr)
            if attr_value:
                set_attr[attr] = attr_value
        return set_attr

    @property
    def all_special_tokens(self):
        """ List all the special tokens ('<unk>', '<cls>'...) mapped to class attributes
            (cls_token, unk_token...).
        """
        all_toks = []
        set_attr = self.special_tokens_map
        for attr_value in set_attr.values():
            all_toks = all_toks + (list(attr_value) if isinstance(attr_value, (list, tuple)) else [attr_value])
        all_toks = list(set(all_toks))
        return all_toks

    @property
    def all_special_ids(self):
        """ List the vocabulary indices of the special tokens ('<unk>', '<cls>'...) mapped to
            class attributes (cls_token, unk_token...).
        """
        all_toks = self.all_special_tokens
        all_ids = self.convert_tokens_to_ids(all_toks)
        return all_ids

    @staticmethod
    def clean_up_tokenization(out_string):
        """ Clean up a list of simple English tokenization artifacts like spaces before punctuations and abreviated forms.
        """
        out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
                        ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
                        ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
        return out_string