"examples/vscode:/vscode.git/clone" did not exist on "691176283d81f5927218d81ff027b84097dd2a37"
tokenization_utils.py 48 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import logging
import os
21
22
import json
import six
23
import copy
24
25
26
27
28
29
from io import open

from .file_utils import cached_path

# Logger for this module.
logger = logging.getLogger(__name__)

# File names written by ``save_pretrained`` and looked up by ``_from_pretrained``
# next to the vocabulary files of a saved tokenizer.
SPECIAL_TOKENS_MAP_FILE = 'special_tokens_map.json'
ADDED_TOKENS_FILE = 'added_tokens.json'
TOKENIZER_CONFIG_FILE = 'tokenizer_config.json'
33
34

class PreTrainedTokenizer(object):
35
36
    """ Base class for all tokenizers.
    Handle all the shared methods for tokenization and special tokens, as well as methods for downloading/caching/loading pretrained tokenizers and for adding tokens to the vocabulary.
37

38
    This class also contains the added tokens in a unified way on top of all tokenizers so we don't have to handle the specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...).
39

40
41
42
43
44
    Class attributes (overridden by derived classes):

        - ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file required by the model, and as associated values, the filename for saving the associated file (string).
        - ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the associated pretrained vocabulary file.
        - ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, the maximum length of the sequence inputs of this model, or None if the model has no maximum input size.
45
        - ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, a dictionary of specific arguments to pass to the ``__init__`` method of the tokenizer class for this pretrained model when loading the tokenizer with the ``from_pretrained()`` method.
46
47
48

    Parameters:

thomwolf's avatar
thomwolf committed
49
        - ``bos_token``: (`Optional`) string: a beginning of sentence token. Will be associated to ``self.bos_token`` and ``self.bos_token_id``
50

thomwolf's avatar
thomwolf committed
51
        - ``eos_token``: (`Optional`) string: an end of sentence token. Will be associated to ``self.eos_token`` and ``self.eos_token_id``
52

thomwolf's avatar
thomwolf committed
53
        - ``unk_token``: (`Optional`) string: an unknown token. Will be associated to ``self.unk_token`` and ``self.unk_token_id``
54

thomwolf's avatar
thomwolf committed
55
        - ``sep_token``: (`Optional`) string: a separation token (e.g. to separate context and query in an input sequence). Will be associated to ``self.sep_token`` and ``self.sep_token_id``
56

thomwolf's avatar
thomwolf committed
57
        - ``pad_token``: (`Optional`) string: a padding token. Will be associated to ``self.pad_token`` and ``self.pad_token_id``
58

thomwolf's avatar
thomwolf committed
59
        - ``cls_token``: (`Optional`) string: a classification token (e.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model). Will be associated to ``self.cls_token`` and ``self.cls_token_id``
60

thomwolf's avatar
thomwolf committed
61
        - ``mask_token``: (`Optional`) string: a masking token (e.g. when training a model with masked-language modeling). Will be associated to ``self.mask_token`` and ``self.mask_token_id``
62

thomwolf's avatar
thomwolf committed
63
        - ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens. Adding all special tokens here ensures they won't be split by the tokenization process. Will be associated to ``self.additional_special_tokens`` and ``self.additional_special_tokens_ids``
64
65
66
    """
    vocab_files_names = {}
    pretrained_vocab_files_map = {}
67
    pretrained_init_configuration = {}
68
69
    max_model_input_sizes = {}

70
71
72
73
74
75
    SPECIAL_TOKENS_ATTRIBUTES = ["bos_token", "eos_token", "unk_token", "sep_token",
                                 "pad_token", "cls_token", "mask_token",
                                 "additional_special_tokens"]

    @property
    def bos_token(self):
76
        """ Beginning of sentence token (string). Log an error if used while not having been set. """
77
78
79
80
81
82
        if self._bos_token is None:
            logger.error("Using bos_token, but it is not set yet.")
        return self._bos_token

    @property
    def eos_token(self):
83
        """ End of sentence token (string). Log an error if used while not having been set. """
84
85
86
87
88
89
        if self._eos_token is None:
            logger.error("Using eos_token, but it is not set yet.")
        return self._eos_token

    @property
    def unk_token(self):
90
        """ Unknown token (string). Log an error if used while not having been set. """
91
92
93
94
95
96
        if self._unk_token is None:
            logger.error("Using unk_token, but it is not set yet.")
        return self._unk_token

    @property
    def sep_token(self):
97
        """ Separation token (string). E.g. separate context and query in an input sequence. Log an error if used while not having been set. """
98
99
100
101
102
103
        if self._sep_token is None:
            logger.error("Using sep_token, but it is not set yet.")
        return self._sep_token

    @property
    def pad_token(self):
104
        """ Padding token (string). Log an error if used while not having been set. """
105
106
107
108
109
110
        if self._pad_token is None:
            logger.error("Using pad_token, but it is not set yet.")
        return self._pad_token

    @property
    def cls_token(self):
111
        """ Classification token (string). E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """
112
113
114
115
116
117
        if self._cls_token is None:
            logger.error("Using cls_token, but it is not set yet.")
        return self._cls_token

    @property
    def mask_token(self):
118
        """ Mask token (string). E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """
119
120
121
122
123
124
        if self._mask_token is None:
            logger.error("Using mask_token, but it is not set yet.")
        return self._mask_token

    @property
    def additional_special_tokens(self):
125
        """ All the additional special tokens you may want to use (list of strings). Log an error if used while not having been set. """
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
        if self._additional_special_tokens is None:
            logger.error("Using additional_special_tokens, but it is not set yet.")
        return self._additional_special_tokens

    @bos_token.setter
    def bos_token(self, value):
        self._bos_token = value

    @eos_token.setter
    def eos_token(self, value):
        self._eos_token = value

    @unk_token.setter
    def unk_token(self, value):
        self._unk_token = value

    @sep_token.setter
    def sep_token(self, value):
        self._sep_token = value

    @pad_token.setter
    def pad_token(self, value):
        self._pad_token = value

    @cls_token.setter
    def cls_token(self, value):
        self._cls_token = value

    @mask_token.setter
    def mask_token(self, value):
        self._mask_token = value

    @additional_special_tokens.setter
    def additional_special_tokens(self, value):
        self._additional_special_tokens = value

162
163
164
    @property
    def bos_token_id(self):
        """ Id of the beginning of sentence token in the vocabulary. Log an error if used while not having been set. """
165
        return self.convert_tokens_to_ids(self.bos_token)
166
167
168
169

    @property
    def eos_token_id(self):
        """ Id of the end of sentence token in the vocabulary. Log an error if used while not having been set. """
170
        return self.convert_tokens_to_ids(self.eos_token)
171
172

    @property
maru0kun's avatar
maru0kun committed
173
    def unk_token_id(self):
174
        """ Id of the unknown token in the vocabulary. Log an error if used while not having been set. """
175
        return self.convert_tokens_to_ids(self.unk_token)
176
177
178
179

    @property
    def sep_token_id(self):
        """ Id of the separation token in the vocabulary. E.g. separate context and query in an input sequence. Log an error if used while not having been set. """
180
        return self.convert_tokens_to_ids(self.sep_token)
181
182
183
184

    @property
    def pad_token_id(self):
        """ Id of the padding token in the vocabulary. Log an error if used while not having been set. """
185
        return self.convert_tokens_to_ids(self.pad_token)
186
187
188
189

    @property
    def cls_token_id(self):
        """ Id of the classification token in the vocabulary. E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """
190
        return self.convert_tokens_to_ids(self.cls_token)
191
192
193
194

    @property
    def mask_token_id(self):
        """ Id of the mask token in the vocabulary. E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """
195
        return self.convert_tokens_to_ids(self.mask_token)
196
197
198
199

    @property
    def additional_special_tokens_ids(self):
        """ Ids of all the additional special tokens in the vocabulary (list of integers). Log an error if used while not having been set. """
200
        return self.convert_tokens_to_ids(self.additional_special_tokens)
201

202
203
204
205
206
207
208
209
210
211
212
    def __init__(self, max_len=None, **kwargs):
        self._bos_token = None
        self._eos_token = None
        self._unk_token = None
        self._sep_token = None
        self._pad_token = None
        self._cls_token = None
        self._mask_token = None
        self._additional_special_tokens = []

        self.max_len = max_len if max_len is not None else int(1e12)
213
214

        # Added tokens
215
216
217
        self.added_tokens_encoder = {}
        self.added_tokens_decoder = {}

218
219
220
221
        # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
        self.init_inputs = ()
        self.init_kwargs = {}

222
        for key, value in kwargs.items():
223
            if key in self.SPECIAL_TOKENS_ATTRIBUTES:
224
225
226
227
                if key == 'additional_special_tokens':
                    assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value)
                else:
                    assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
228
229
230
                setattr(self, key, value)


231
232
    @classmethod
    def from_pretrained(cls, *inputs, **kwargs):
LysandreJik's avatar
Doc  
LysandreJik committed
233
234
        r"""
        Instantiate a :class:`~pytorch_transformers.PreTrainedTokenizer` (or a derived class) from a predefined tokenizer.
235

LysandreJik's avatar
Doc  
LysandreJik committed
236
        Args:
237
238
239
240
241
242
243
244
245
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``.
                - a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~pytorch_transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``.
                - (not applicable to all derived classes) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used.

246
247
248
            force_download: (`optional`) boolean, default False:
                Force to (re-)download the vocabulary files and override the cached versions if they exists.

249
250
251
252
            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
            inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method.

            kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~pytorch_transformers.PreTrainedTokenizer` for details.

        Examples::

            # We can't instantiate directly the base class `PreTrainedTokenizer` so let's show our examples on a derived class: BertTokenizer

            # Download vocabulary from S3 and cache.
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

            # If vocabulary files are in a directory (e.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`)
            tokenizer = BertTokenizer.from_pretrained('./test/saved_model/')

            # If the tokenizer uses a single vocabulary file, you can point directly to this file
            tokenizer = BertTokenizer.from_pretrained('./test/saved_model/my_vocab.txt')

            # You can link tokens to special vocabulary when instantiating
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', unk_token='<unk>')
            # You should be sure '<unk>' is in the vocabulary when doing that.
            # Otherwise use tokenizer.add_special_tokens({'unk_token': '<unk>'}) instead)
            assert tokenizer.unk_token == '<unk>'

        """
277
278
        return cls._from_pretrained(*inputs, **kwargs)

279

280
    @classmethod
    def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
        """ Locate and cache the files of a tokenizer, then instantiate ``cls`` with them.

        This is the implementation behind :meth:`from_pretrained`; see that method for details.

        Args:
            pretrained_model_name_or_path: a shortcut name listed in ``cls.max_model_input_sizes``, a directory,
                or a path to a single vocabulary file.
            init_inputs: positional arguments for ``cls.__init__``. When empty, the ``init_inputs`` stored in a
                saved tokenizer config file (if one is found) are used instead.
            kwargs: ``cache_dir``, ``force_download`` and ``proxies`` are consumed here and forwarded to
                ``cached_path``. Every other keyword overrides the saved/pretrained ``__init__`` kwargs.

        Returns:
            The instantiated tokenizer, or ``None`` when the name is not a known shortcut and no tokenizer
            file exists at the given path.

        Raises:
            EnvironmentError: re-raised when ``cached_path`` fails to resolve one of the files.
        """
        cache_dir = kwargs.pop('cache_dir', None)
        force_download = kwargs.pop('force_download', False)
        proxies = kwargs.pop('proxies', None)

        s3_models = list(cls.max_model_input_sizes.keys())
        vocab_files = {}
        init_configuration = {}
        if pretrained_model_name_or_path in s3_models:
            # Get the vocabulary from AWS S3 bucket
            for file_id, map_list in cls.pretrained_vocab_files_map.items():
                vocab_files[file_id] = map_list[pretrained_model_name_or_path]
            if cls.pretrained_init_configuration and pretrained_model_name_or_path in cls.pretrained_init_configuration:
                init_configuration = cls.pretrained_init_configuration[pretrained_model_name_or_path]
        else:
            # Get the vocabulary from local files
            logger.info(
                "Model name '{}' not found in model shortcut name list ({}). "
                "Assuming '{}' is a path or url to a directory containing tokenizer files.".format(
                    pretrained_model_name_or_path, ', '.join(s3_models),
                    pretrained_model_name_or_path))

            # Look for the tokenizer main vocabulary files
            for file_id, file_name in cls.vocab_files_names.items():
                if os.path.isdir(pretrained_model_name_or_path):
                    # If a directory is provided we look for the standard filenames
                    full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
                else:
                    # If a path to a file is provided we use it (will only work for non-BPE tokenizer using a single vocabulary file)
                    full_file_name = pretrained_model_name_or_path
                if not os.path.exists(full_file_name):
                    logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
                    full_file_name = None
                vocab_files[file_id] = full_file_name

            # Look for the additional tokens files
            additional_files_names = {'added_tokens_file': ADDED_TOKENS_FILE,
                                      'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE,
                                      'tokenizer_config_file': TOKENIZER_CONFIG_FILE,
                                      }

            # If a path to a file was provided, get the parent directory
            saved_directory = pretrained_model_name_or_path
            if os.path.exists(saved_directory) and not os.path.isdir(saved_directory):
                saved_directory = os.path.dirname(saved_directory)

            for file_id, file_name in additional_files_names.items():
                full_file_name = os.path.join(saved_directory, file_name)
                if not os.path.exists(full_file_name):
                    logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
                    full_file_name = None
                vocab_files[file_id] = full_file_name

            if all(full_file_name is None for full_file_name in vocab_files.values()):
                # Trailing space added: the two literals used to run together ("filesat").
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find tokenizer files "
                    "at this path or url.".format(
                        pretrained_model_name_or_path, ', '.join(s3_models),
                        pretrained_model_name_or_path, ))
                return None

        # Get files from url, cache, or disk depending on the case
        try:
            resolved_vocab_files = {}
            for file_id, file_path in vocab_files.items():
                if file_path is None:
                    resolved_vocab_files[file_id] = None
                else:
                    resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
        except EnvironmentError as e:
            if pretrained_model_name_or_path in s3_models:
                logger.error("Couldn't reach server to download vocabulary.")
            else:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find files {} "
                    "at this path or url.".format(
                        pretrained_model_name_or_path, ', '.join(s3_models),
                        pretrained_model_name_or_path, str(vocab_files.keys())))
            raise e

        for file_id, file_path in vocab_files.items():
            if file_path == resolved_vocab_files[file_id]:
                logger.info("loading file {}".format(file_path))
            else:
                logger.info("loading file {} from cache at {}".format(
                    file_path, resolved_vocab_files[file_id]))

        # Prepare tokenizer initialization kwargs
        # Did we saved some inputs and kwargs to reload ?
        tokenizer_config_file = resolved_vocab_files.pop('tokenizer_config_file', None)
        if tokenizer_config_file is not None:
            with open(tokenizer_config_file, encoding="utf-8") as config_reader:
                init_kwargs = json.load(config_reader)
            saved_init_inputs = init_kwargs.pop('init_inputs', ())
            if not init_inputs:
                init_inputs = saved_init_inputs
        else:
            # Copy: the updates below must not leak into the class-level
            # ``pretrained_init_configuration`` shared by every later call.
            init_kwargs = dict(init_configuration)

        # Update with newly provided kwargs
        init_kwargs.update(kwargs)

        # Set max length if needed
        if pretrained_model_name_or_path in cls.max_model_input_sizes:
            # if we're using a pretrained model, ensure the tokenizer
            # wont index sequences longer than the number of positional embeddings
            max_len = cls.max_model_input_sizes[pretrained_model_name_or_path]
            if max_len is not None and isinstance(max_len, (int, float)):
                # An explicit ``max_len=None`` means "no user limit", same as not passing it.
                requested_max_len = init_kwargs.get('max_len')
                if requested_max_len is None:
                    requested_max_len = int(1e12)
                init_kwargs['max_len'] = min(requested_max_len, max_len)

        # Merge resolved_vocab_files arguments in init_kwargs.
        added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None)
        special_tokens_map_file = resolved_vocab_files.pop('special_tokens_map_file', None)
        for args_name, file_path in resolved_vocab_files.items():
            if args_name not in init_kwargs:
                init_kwargs[args_name] = file_path
        if special_tokens_map_file is not None:
            with open(special_tokens_map_file, encoding="utf-8") as special_tokens_reader:
                special_tokens_map = json.load(special_tokens_reader)
            for key, value in special_tokens_map.items():
                if key not in init_kwargs:
                    init_kwargs[key] = value

        # Instantiate tokenizer.
        tokenizer = cls(*init_inputs, **init_kwargs)

        # Save inputs and kwargs for saving and re-loading with ``save_pretrained``
        tokenizer.init_inputs = init_inputs
        tokenizer.init_kwargs = init_kwargs

        # Add supplementary tokens.
        if added_tokens_file is not None:
            with open(added_tokens_file, encoding="utf-8") as added_tokens_reader:
                added_tok_encoder = json.load(added_tokens_reader)
            added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
            tokenizer.added_tokens_encoder.update(added_tok_encoder)
            tokenizer.added_tokens_decoder.update(added_tok_decoder)

        return tokenizer

thomwolf's avatar
thomwolf committed
420

421
    def save_pretrained(self, save_directory):
        """ Save the full tokenizer state into ``save_directory``, so that it can be re-loaded with the
            :func:`~pytorch_transformers.PreTrainedTokenizer.from_pretrained` class method.

            Files written:
                - the tokenizer configuration: instantiation positional and keyword inputs
                  (e.g. do_lower_case for Bert), without the vocabulary file paths,
                - the special-tokens-to-class-attributes mapping,
                - the added tokens,
                - the vocabulary files, through :meth:`save_vocabulary`.

            Modifications made after instantiation other than added tokens and special-token mapping
            (e.g. changing tokenizer.do_lower_case after creation) are NOT saved.

            Returns:
                The tuple returned by :meth:`save_vocabulary` extended with the special-tokens-map and
                added-tokens file paths, or ``None`` (after logging an error) when ``save_directory``
                is not a directory.
        """
        if not os.path.isdir(save_directory):
            logger.error("Saving directory ({}) should be a directory".format(save_directory))
            return

        tokenizer_config_file = os.path.join(save_directory, TOKENIZER_CONFIG_FILE)
        special_tokens_map_file = os.path.join(save_directory, SPECIAL_TOKENS_MAP_FILE)
        added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE)

        # Deep copies keep the live ``init_kwargs`` / ``init_inputs`` untouched.
        saved_config = copy.deepcopy(self.init_kwargs)
        saved_config['init_inputs'] = copy.deepcopy(self.init_inputs)
        # Vocabulary file paths are dropped from the config; the vocabulary is saved separately below.
        for vocab_file_id in self.vocab_files_names:
            saved_config.pop(vocab_file_id, None)

        with open(tokenizer_config_file, 'w', encoding='utf-8') as writer:
            writer.write(json.dumps(saved_config, ensure_ascii=False))

        with open(special_tokens_map_file, 'w', encoding='utf-8') as writer:
            writer.write(json.dumps(self.special_tokens_map, ensure_ascii=False))

        with open(added_tokens_file, 'w', encoding='utf-8') as writer:
            writer.write(json.dumps(self.added_tokens_encoder, ensure_ascii=False)
                         if self.added_tokens_encoder else u"{}")

        saved_vocab_files = self.save_vocabulary(save_directory)
        return saved_vocab_files + (special_tokens_map_file, added_tokens_file)


    def save_vocabulary(self, save_directory):
464
        """ Save the tokenizer vocabulary to a directory. This method does *NOT* save added tokens
465
            and special token mappings.
466
467

            Please use :func:`~pytorch_transformers.PreTrainedTokenizer.save_pretrained` `()` to save the full Tokenizer state if you want to reload it using the :func:`~pytorch_transformers.PreTrainedTokenizer.from_pretrained` class method.
468
        """
thomwolf's avatar
thomwolf committed
469
470
        raise NotImplementedError

471
472

    def vocab_size(self):
473
        """ Size of the base vocabulary (without the added tokens) """
thomwolf's avatar
thomwolf committed
474
475
        raise NotImplementedError

476
477

    def __len__(self):
478
        """ Size of the full vocabulary with the added tokens """
479
480
481
482
        return self.vocab_size + len(self.added_tokens_encoder)


    def add_tokens(self, new_tokens):
LysandreJik's avatar
Doc  
LysandreJik committed
483
484
        """
        Add a list of new tokens to the tokenizer class. If the new tokens are not in the
485
486
        vocabulary, they are added to it with indices starting from length of the current vocabulary.

LysandreJik's avatar
Doc  
LysandreJik committed
487
488
        Args:
            new_tokens: list of string. Each string is a token to add. Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assign the index of the ``unk_token`` to them).
489

LysandreJik's avatar
Doc  
LysandreJik committed
490
491
        Returns:
            Number of tokens added to the vocabulary.
492
493
494
495
496
497
498
499
500
501

        Examples::

            # Let's see how to increase the vocabulary of Bert model and tokenizer
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            model = BertModel.from_pretrained('bert-base-uncased')

            num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
            print('We have added', num_added_toks, 'tokens')
            model.resize_token_embeddings(len(tokenizer))  # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
502
503
504
505
506
507
        """
        if not new_tokens:
            return 0

        to_add_tokens = []
        for token in new_tokens:
508
            assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode))
thomwolf's avatar
thomwolf committed
509
510
            if token != self.unk_token and \
                    self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token):
511
512
513
514
515
516
517
518
519
520
                to_add_tokens.append(token)
                logger.info("Adding %s to the vocabulary", token)

        added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(to_add_tokens))
        added_tok_decoder = {v:k for k, v in added_tok_encoder.items()}
        self.added_tokens_encoder.update(added_tok_encoder)
        self.added_tokens_decoder.update(added_tok_decoder)

        return len(to_add_tokens)

521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
    def num_added_tokens(self, pair=False):
        """
        Returns the number of added tokens when encoding a sequence with special tokens.

        Note:
            This encodes inputs and checks the number of added tokens, and is therefore not efficient. Do not put this
            inside your training loop.

        Args:
            pair: Returns the number of added tokens in the case of a sequence pair if set to True, returns the
                number of added tokens in the case of a single sequence if set to False.

        Returns:
            Number of tokens added to sequences
        """

        if pair:
538
            initial_tokens_len = len(self.encode("This is a sequence") + self.encode("This is another"))
LysandreJik's avatar
LysandreJik committed
539
            final_tokens_len = len(self.encode("This is a sequence", "This is another", add_special_tokens=True))
540
541
542
543
544
        else:
            initial_tokens_len = len(self.encode("This is a sequence"))
            final_tokens_len = len(self.encode("This is a sequence", add_special_tokens=True))

        return final_tokens_len - initial_tokens_len
545
546

    def add_special_tokens(self, special_tokens_dict):
        """
        Register special tokens (eos, pad, cls...) on the tokenizer and expose them as class attributes.

        Each value is handed to ``self.add_tokens`` (so tokens that are NOT in the vocabulary are added to it,
        indexed starting from the last index of the current vocabulary), then assigned to the attribute named
        by its key.

        Using `add_special_tokens` will ensure your special tokens can be used in several ways:

        - special tokens are carefully handled by the tokenizer (they are never split)
        - you can easily refer to special tokens using tokenizer class attributes like `tokenizer.cls_token`.
          This makes it easy to develop model-agnostic training and fine-tuning scripts.

        When possible, special tokens are already registered for provided pretrained models (ex: BertTokenizer
        cls_token is already registered to be '[CLS]' and XLM's one is also registered to be '</s>')

        Args:
            special_tokens_dict: dict of string. Keys should be in the list of predefined special attributes:
                [``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``,
                ``mask_token``, ``additional_special_tokens``]. ``additional_special_tokens`` takes a list/tuple
                of strings; every other key takes a single string.

                Tokens are only added if they are not already in the vocabulary (tested by checking if the
                tokenizer assign the index of the ``unk_token`` to them).

        Returns:
            Number of tokens added to the vocabulary (sum of the counts returned by ``self.add_tokens``).
            An empty or falsy ``special_tokens_dict`` returns 0 without touching the tokenizer.

        Raises:
            AssertionError: if a key is not a special-token attribute or a value has the wrong type.

        Examples::

            # Let's see how to add a new classification token to GPT-2
            tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            model = GPT2Model.from_pretrained('gpt2')

            special_tokens_dict = {'cls_token': '<CLS>'}

            num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
            print('We have added', num_added_toks, 'tokens')
            model.resize_token_embeddings(len(tokenizer))  # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.

            assert tokenizer.cls_token == '<CLS>'
        """
        if not special_tokens_dict:
            return 0

        def is_text(candidate):
            # ``unicode`` only exists on Python 2; the ``six.PY2`` guard keeps it from being evaluated on Python 3.
            return isinstance(candidate, str) or (six.PY2 and isinstance(candidate, unicode))

        num_added = 0
        for attr_name, token_value in special_tokens_dict.items():
            assert attr_name in self.SPECIAL_TOKENS_ATTRIBUTES

            if attr_name == 'additional_special_tokens':
                assert isinstance(token_value, (list, tuple)) and all(is_text(t) for t in token_value)
                new_tokens = token_value
            else:
                assert is_text(token_value)
                new_tokens = [token_value]
            num_added += self.add_tokens(new_tokens)

            logger.info("Assigning %s to the %s key of the tokenizer", token_value, attr_name)
            setattr(self, attr_name, token_value)

        return num_added

    def tokenize(self, text, **kwargs):
        """ Converts a string in a sequence of tokens (string), using the tokenizer.
            Split in words for word-based vocabulary or sub-words for sub-word-based
            vocabularies (BPE/SentencePieces/WordPieces).

            Take care of added tokens: added tokens and special tokens are never split and are
            returned unchanged; every other fragment is stripped of surrounding whitespace and
            handed to ``self._tokenize``. Whitespace-only fragments are dropped.

            Args:
                text: the string to tokenize.
                **kwargs: passed to ``self._tokenize``.

            Returns:
                A list of tokens (strings).
        """
        added_tokens = list(self.added_tokens_encoder.keys()) + self.all_special_tokens
        # Built once: membership is tested for every fragment, and ``all_special_tokens``
        # is a property that rebuilds its list on each access.
        no_split_tokens = set(added_tokens)

        def split_on_token(tok, text):
            """Split ``text`` on ``tok``, keeping ``tok`` between the stripped, non-empty pieces."""
            result = []
            split_text = text.split(tok)
            last_index = len(split_text) - 1
            for i, sub_text in enumerate(split_text):
                sub_text = sub_text.strip()
                if sub_text:
                    result.append(sub_text)
                # ``tok`` only sits between piece i and piece i + 1: a text that does not
                # contain ``tok`` (a single piece) must not gain a spurious ``tok``.
                if i < last_index:
                    result.append(tok)
            return result

        def split_on_tokens(tok_list, text):
            """Split ``text`` on every token of ``tok_list`` in turn, then sub-tokenize the rest."""
            if not text:
                return []
            if not tok_list:
                return self._tokenize(text, **kwargs)

            text_list = [text]
            for tok in tok_list:
                tokenized_text = []
                for sub_text in text_list:
                    if sub_text in no_split_tokens:
                        tokenized_text.append(sub_text)
                    else:
                        tokenized_text.extend(split_on_token(tok, sub_text))
                text_list = tokenized_text

            result = []
            for token in text_list:
                if token in no_split_tokens:
                    result.append(token)
                else:
                    result.extend(self._tokenize(token, **kwargs))
            return result

        return split_on_tokens(added_tokens, text)

    def _tokenize(self, text, **kwargs):
        """ Converts a string in a sequence of tokens (string), using the tokenizer.
            Split in words for word-based vocabulary or sub-words for sub-word-based
            vocabularies (BPE/SentencePieces/WordPieces).

656
            Do NOT take care of added tokens.
657
        """
thomwolf's avatar
thomwolf committed
658
659
        raise NotImplementedError

660
    def convert_tokens_to_ids(self, tokens):
        """ Map a single token (str/unicode) to its integer id, or a sequence of tokens to a list
            of ids. Each token is resolved with ``_convert_token_to_id_with_added_voc``.

            Args:
                tokens: ``None``, a single token string, or an iterable of token strings.

            Returns:
                ``None`` for ``None``; an id for a single token; otherwise a list of ids. A warning
                is logged when that list is longer than ``self.max_len``.
        """
        if tokens is None:
            return None

        # ``unicode`` is only evaluated on Python 2 thanks to the ``six.PY2`` short-circuit.
        is_single_token = isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode))
        if is_single_token:
            return self._convert_token_to_id_with_added_voc(tokens)

        ids = [self._convert_token_to_id_with_added_voc(token) for token in tokens]
        if len(ids) > self.max_len:
            logger.warning("Token indices sequence length is longer than the specified maximum sequence length "
                           "for this model ({} > {}). Running this sequence through the model will result in "
                           "indexing errors".format(len(ids), self.max_len))
        return ids

679
    def _convert_token_to_id_with_added_voc(self, token):
680
681
682
        if token is None:
            return None

683
684
685
686
687
        if token in self.added_tokens_encoder:
            return self.added_tokens_encoder[token]
        return self._convert_token_to_id(token)

    def _convert_token_to_id(self, token):
thomwolf's avatar
thomwolf committed
688
689
        raise NotImplementedError

690
    def encode(self, text, text_pair=None, add_special_tokens=False, **kwargs):
        """
        Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.

        Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``.

        Args:
            text: The first sequence to be encoded. This can be a string, a list of strings (tokenized string using
                the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
                method)
            text_pair: Optional second sequence to be encoded. This can be a string, a list of strings (tokenized
                string using the `tokenize` method) or a list of integers (tokenized string ids using the
                `convert_tokens_to_ids` method)
            add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
                to their model.
            **kwargs: forwarded to ``self.encode_plus`` (e.g. ``max_length``, ``stride``); keywords it does not
                consume reach the `self.tokenize()` method.

        Returns:
            The ``"input_ids"`` entry of the dictionary built by ``self.encode_plus``.
        """
        encoded = self.encode_plus(text, text_pair, add_special_tokens, **kwargs)
        input_ids = encoded["input_ids"]
        return input_ids

709
710
711
712
    def encode_plus(self,
                    text,
                    text_pair=None,
                    add_special_tokens=False,
                    output_token_type=False,
                    max_length=None,
                    stride=0,
                    truncate_first_sequence=True,
                    **kwargs):
        """
        Returns a dictionary containing the encoded sequence or sequence pair. Other values can be returned by this
        method: the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.

        Args:
            text: The first sequence to be encoded. This can be a string, a list of strings (tokenized string using
                the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
                method)
            text_pair: Optional second sequence to be encoded. This can be a string, a list of strings (tokenized
                string using the `tokenize` method) or a list of integers (tokenized string ids using the
                `convert_tokens_to_ids` method)
            add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
                to their model.
            output_token_type: if set to ``True``, returns the text pair corresponding mask with 0 for the first
                sequence, and 1 for the second.
            max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
                If there are overflowing tokens, those will be added to the returned dictionary
            stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
                from the main sequence returned. The value of this argument defines the number of additional tokens.
            truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
                will be truncated.
            **kwargs: passed to the `self.tokenize()` method

        Returns:
            A dict with ``input_ids`` and, depending on the arguments, ``overflowing_tokens`` and
            ``output_token_type``.

        Raises:
            ValueError: if an input is not a string, a non-empty list/tuple of strings or a non-empty list/tuple
                of integers.
        """
        information = {}

        def get_input_ids(text):
            # The accepted input shapes are mutually exclusive, so testing integer lists first does not
            # change which branch a given input takes.
            if isinstance(text, (list, tuple)) and len(text) > 0:
                if isinstance(text[0], int):
                    return text
                if isinstance(text[0], six.string_types):
                    return self.convert_tokens_to_ids(text)
            elif isinstance(text, six.string_types):
                return self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
            raise ValueError("Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers.")

        if text_pair is None:
            sequence_tokens = get_input_ids(text)

            if add_special_tokens:
                information = self.prepare_for_model(sequence_tokens, max_length=max_length, stride=stride)
            else:
                if max_length:
                    # Clamped at 0: a stride larger than max_length must not become a negative
                    # (wrap-around) slice index.
                    information["overflowing_tokens"] = sequence_tokens[max(max_length - stride, 0):]
                    sequence_tokens = sequence_tokens[:max_length]
                information["input_ids"] = sequence_tokens

            if output_token_type:
                information["output_token_type"] = [0] * len(information["input_ids"])
        else:
            first_sentence_tokens = get_input_ids(text)
            second_sentence_tokens = get_input_ids(text_pair)

            if add_special_tokens:
                information = self.prepare_pair_for_model(
                    first_sentence_tokens,
                    second_sentence_tokens,
                    max_length=max_length,
                    truncate_first_sequence=truncate_first_sequence,
                    stride=stride
                )

                if output_token_type:
                    # NOTE(review): the mask is computed from the raw inputs; the base
                    # ``create_mask_from_sequences`` re-encodes them, so it does not reflect any truncation
                    # applied above when ``max_length`` is set — TODO confirm for derived tokenizers.
                    information["output_token_type"] = self.create_mask_from_sequences(text, text_pair)
            else:
                logger.warning("No special tokens were added. The two sequences have been concatenated.")
                sequence = first_sentence_tokens + second_sentence_tokens
                # 0 for every token of the first sequence, 1 for every token of the second, as documented.
                token_type = [0] * len(first_sentence_tokens) + [1] * len(second_sentence_tokens)

                if max_length:
                    information["overflowing_tokens"] = sequence[max(max_length - stride, 0):]
                    sequence = sequence[:max_length]
                    token_type = token_type[:max_length]

                if output_token_type:
                    information["output_token_type"] = token_type

                information["input_ids"] = sequence

        return information

798
    def prepare_for_model(self, ids, max_length=None, stride=0):
        """
        Prepares a list of tokenized input ids so that it can be used by the model. It adds special tokens, truncates
        sequences if overflowing while taking into account the special tokens and manages a window stride for
        overflowing tokens

        Args:
            ids: list of tokenized input ids. Can be obtained from a string by chaining the
                `tokenize` and `convert_tokens_to_ids` methods.
            max_length: maximum length of the returned list. Will truncate by taking into account the special tokens.
            stride: window stride for overflowing tokens. Can be useful for edge effect removal when using sequential
                list of inputs.

        Returns:
            a dictionary containing the `input_ids` as well as the `overflowing_tokens` if a `max_length` was given.
        """
        information = {}
        if max_length:
            # Number of ids that fit once the model's special tokens are accounted for. Clamped at 0 so that
            # a max_length smaller than the number of special tokens, or a stride larger than the room left,
            # cannot turn into a negative (wrap-around) slice index.
            room = max(max_length - self.num_added_tokens(), 0)
            information["overflowing_tokens"] = ids[max(room - stride, 0):]
            ids = ids[:room]

        information["input_ids"] = self.add_special_tokens_single_sequence(ids)
        return information

LysandreJik's avatar
LysandreJik committed
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
    def prepare_pair_for_model(self, ids_0, ids_1, max_length=None, truncate_first_sequence=True, stride=0):
        """
        Prepares a list of tokenized input ids pair so that it can be used by the model. It adds special tokens,
        truncates sequences if overflowing while taking into account the special tokens and manages a window stride for
        overflowing tokens

        Args:
            ids_0: list of tokenized input ids. Can be obtained from a string by chaining the
                `tokenize` and `convert_tokens_to_ids` methods.
            ids_1: second list of tokenized input ids. Can be obtained from a string by chaining the
                `tokenize` and `convert_tokens_to_ids` methods.
            max_length: maximum length of the returned list. Will truncate by taking into account the special tokens.
            truncate_first_sequence: if set to `True`, alongside a specified `max_length`, will truncate the first
                sequence if the total size is superior than the specified `max_length`. If set to `False`, will
                truncate the second sequence instead.
            stride: window stride for overflowing tokens. Can be useful for edge effect removal when using sequential
                list of inputs.

        Return:
            a dictionary containing the `input_ids` as well as the `overflowing_tokens` if a `max_length` was given.
        """
844
845
846
847
        f_len, s_len = len(ids_0), len(ids_1)
        information = {}

        if max_length:
LysandreJik's avatar
LysandreJik committed
848
            n_added_tokens = self.num_added_tokens(pair=True)
849
850
851
852
853
            if len(ids_0) + n_added_tokens >= max_length:
                logger.warning(
                    "The first sequence is longer than the maximum specified length. This sequence will not be truncated.")
            else:
                if f_len + s_len + self.num_added_tokens(pair=True) > max_length:
LysandreJik's avatar
LysandreJik committed
854
                    if truncate_first_sequence:
855
856
                        information["overflowing_tokens"] = ids_0[max_length - s_len - n_added_tokens - stride:]
                        ids_0 = ids_0[:max_length - s_len - n_added_tokens]
LysandreJik's avatar
LysandreJik committed
857
858
859
                    else:
                        information["overflowing_tokens"] = ids_1[max_length - f_len - n_added_tokens - stride:]
                        ids_1 = ids_1[:max_length - f_len - n_added_tokens]
860
861

        sequence = self.add_special_tokens_sequence_pair(ids_0, ids_1)
LysandreJik's avatar
LysandreJik committed
862
        information["input_ids"] = sequence
863
864
865

        return information

866
867
868
869
    def create_mask_from_sequences(self, sequence_0, sequence_1):
        """ Base-class token-type mask for a pair, built without special tokens: one ``0`` per id that
            ``self.encode`` yields for ``sequence_0``, followed by one ``1`` per id of ``sequence_1``.
            Logs a warning because this tokenizer defines no special tokens.
        """
        logger.warning("This tokenizer does not make use of special tokens.")
        first_len = len(self.encode(sequence_0))
        second_len = len(self.encode(sequence_1))
        return [0] * first_len + [1] * second_len

870
    def add_special_tokens_single_sequence(self, token_ids):
        """ Base-class fallback for a single sequence: this tokenizer defines no special tokens, so a
            warning is logged and ``token_ids`` itself (same object, unmodified) is returned.
        """
        message = "This tokenizer does not make use of special tokens. The sequence has been returned with no modification."
        logger.warning(message)
        return token_ids

874
    def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1):
        """ Base-class fallback for a pair: this tokenizer defines no special tokens, so a warning is
            logged and the two sequences are simply concatenated (``token_ids_0 + token_ids_1``).
        """
        message = "This tokenizer does not make use of special tokens. The two sequences have been concatenated."
        logger.warning(message)
        concatenated = token_ids_0 + token_ids_1
        return concatenated

878
879
880
881
882
883
884
885
    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        """ Converts a single index or a sequence of indices (integers) in a token
            (resp. a sequence of tokens (str/unicode)), using the vocabulary and added tokens.

            Args:
                ids: a single index (int) or an iterable of indices.
                skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False.
                    Only applies when ``ids`` is a sequence.

            Returns:
                a single token for an int input, otherwise a list of tokens.
        """
        if isinstance(ids, int):
            if ids in self.added_tokens_decoder:
                return self.added_tokens_decoder[ids]
            return self._convert_id_to_token(ids)

        # ``all_special_ids`` is a property that rebuilds its list on every access: evaluate it
        # once instead of once per index.
        special_ids = set(self.all_special_ids) if skip_special_tokens else frozenset()

        tokens = []
        for index in ids:
            if index in special_ids:
                continue
            if index in self.added_tokens_decoder:
                tokens.append(self.added_tokens_decoder[index])
            else:
                tokens.append(self._convert_id_to_token(index))
        return tokens

    def _convert_id_to_token(self, index):
thomwolf's avatar
thomwolf committed
901
902
        raise NotImplementedError

903
904
905
906
    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) in a single string.

            The base implementation joins the tokens with single spaces; derived tokenizers may
            override it to also remove sub-word tokenization artifacts.

            Args:
                tokens: a sequence of token strings.

            Returns:
                The tokens joined with ``' '``.
        """
        # ``tokens`` are already strings: they must not be routed through ``convert_ids_to_tokens``,
        # which expects integer ids.
        return ' '.join(tokens)

    def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
        """
        Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
        with options to remove special tokens and clean up tokenization spaces.
        Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.

        Args:
            token_ids: sequence of token ids.
            skip_special_tokens: if ``True``, special tokens are left out of the decoded text.
            clean_up_tokenization_spaces: if ``True``, ``self.clean_up_tokenization`` is applied to the result.

        Returns:
            A string; or, when the decoded text contains the separator token, a list with one string per
            non-empty segment between separators (the cls token, if any, is treated as a separator).
        """
        filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
        # The tokens are strings, so the extra filter compares them with the special token strings
        # (not with their integer ids).
        special_tokens = set(self.all_special_tokens) if skip_special_tokens else set()

        # To avoid mixing byte-level and unicode for byte-level BPT
        # we need to build string separately for added tokens and byte-level tokens
        # cf. https://github.com/huggingface/pytorch-transformers/issues/1133
        sub_texts = []
        current_sub_text = []
        for token in filtered_tokens:
            if token in special_tokens:
                continue
            if token in self.added_tokens_encoder:
                if current_sub_text:
                    sub_texts.append(self.convert_tokens_to_string(current_sub_text))
                    current_sub_text = []
                sub_texts.append(" " + token)
            else:
                current_sub_text.append(token)
        if current_sub_text:
            sub_texts.append(self.convert_tokens_to_string(current_sub_text))
        text = ''.join(sub_texts)

        if self._sep_token is not None and self._sep_token in text:
            # The cls token is treated as one more separator; tokenizers without a cls token skip this
            # (str.replace(None, ...) raises TypeError).
            if self._cls_token is not None:
                text = text.replace(self._cls_token, self._sep_token)
            segments = [segment for segment in text.split(self._sep_token) if len(segment) > 0]
            if clean_up_tokenization_spaces:
                return [self.clean_up_tokenization(segment) for segment in segments]
            return segments

        if clean_up_tokenization_spaces:
            return self.clean_up_tokenization(text)
        return text

    @property
    def special_tokens_map(self):
        """ Dictionary from each special-token attribute name (cls_token, unk_token...) whose private
            ``_<name>`` value is truthy to that value ('<cls>', '<unk>'...). Unset attributes are omitted.
        """
        token_map = {}
        for name in self.SPECIAL_TOKENS_ATTRIBUTES:
            value = getattr(self, "_" + name)
            if not value:
                continue
            token_map[name] = value
        return token_map

    @property
    def all_special_tokens(self):
        """ List all the special tokens ('<unk>', '<cls>'...) mapped to class attributes
            (cls_token, unk_token...), without duplicates.

            The order is deterministic: tokens follow the order of ``special_tokens_map``, entries of
            list-valued attributes (e.g. ``additional_special_tokens``) keep their own order, and only the
            first occurrence of a repeated token is kept.
        """
        # ``tokenize`` splits text on these tokens in list order, so the order must not depend on
        # string-hash randomization (as it would with ``list(set(...))``).
        all_toks = []
        seen = set()
        for attr_value in self.special_tokens_map.values():
            values = list(attr_value) if isinstance(attr_value, (list, tuple)) else [attr_value]
            for tok in values:
                if tok not in seen:
                    seen.add(tok)
                    all_toks.append(tok)
        return all_toks

    @property
    def all_special_ids(self):
        """ List the vocabulary indices of the special tokens ('<unk>', '<cls>'...) mapped to
            class attributes (cls_token, unk_token...), in the order of ``all_special_tokens``.
        """
        # Resolve through the added-tokens vocabulary as well as the base one: special tokens registered
        # via ``add_tokens`` are looked up in ``added_tokens_encoder`` by ``_convert_token_to_id_with_added_voc``,
        # while ``_convert_token_to_id`` alone would give them the base vocabulary's id (presumably unk).
        return [self._convert_token_to_id_with_added_voc(t) for t in self.all_special_tokens]

thomwolf's avatar
thomwolf committed
985
986
    @staticmethod
    def clean_up_tokenization(out_string):
        """ Clean up a list of simple English tokenization artifacts like spaces before punctuations and
            abbreviated forms (e.g. ``"is n't ."`` -> ``"isn't."``).

            Only spaces introduced by tokenization are removed; words are never rewritten, so
            ``"do not"`` stays ``"do not"``.

            Args:
                out_string: the decoded text.

            Returns:
                The cleaned-up text.
        """
        replacements = (
            (' .', '.'),
            (' ?', '?'),
            (' !', '!'),
            (' ,', ','),
            (" ' ", "'"),
            (" n't", "n't"),
            (" 'm", "'m"),
            (" 's", "'s"),
            (" 've", "'ve"),
            (" 're", "'re"),
        )
        for artifact, fixed in replacements:
            out_string = out_string.replace(artifact, fixed)
        return out_string