# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base tokenizer class and utilities shared by all tokenizers."""
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import logging
import os
import json
import six
from io import open

from .file_utils import cached_path

logger = logging.getLogger(__name__)

# File names written by ``PreTrainedTokenizer.save_pretrained`` and looked up by
# ``PreTrainedTokenizer.from_pretrained``: the special-token attribute mapping and the
# tokens added on top of the base vocabulary.
SPECIAL_TOKENS_MAP_FILE = 'special_tokens_map.json'
ADDED_TOKENS_FILE = 'added_tokens.json'

class PreTrainedTokenizer(object):
    """ Base class for all tokenizers.
    Handle all the shared methods for tokenization and special tokens as well as methods downloading/caching/loading pretrained tokenizers as well as adding tokens to the vocabulary.

    This class also contains the added tokens in a unified way on top of all tokenizers so we don't have to handle the specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...).

    Class attributes (overridden by derived classes):

        - ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file required by the model, and as associated values, the filename for saving the associated file (string).
        - ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the associated pretrained vocabulary file.
        - ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, the maximum length of the sequence inputs of this model, or None if the model has no maximum input size.

    Derived classes must implement ``vocab_size`` (as a property), ``_tokenize``, ``_convert_token_to_id``,
    ``_convert_id_to_token`` and ``save_vocabulary``; the base versions raise ``NotImplementedError``.

    Parameters:

        - ``max_len``: (`Optional`) int: maximum number of token ids expected for one sequence. ``convert_tokens_to_ids`` logs a warning when it produces a longer sequence. Defaults to ``int(1e12)``.

        - ``bos_token``: (`Optional`) string: a beginning of sentence token. Will be associated to ``self.bos_token``

        - ``eos_token``: (`Optional`) string: an end of sentence token. Will be associated to ``self.eos_token``

        - ``unk_token``: (`Optional`) string: an unknown token. Will be associated to ``self.unk_token``

        - ``sep_token``: (`Optional`) string: a separation token (e.g. to separate context and query in an input sequence). Will be associated to ``self.sep_token``

        - ``pad_token``: (`Optional`) string: a padding token. Will be associated to ``self.pad_token``

        - ``cls_token``: (`Optional`) string: a classification token (e.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model). Will be associated to ``self.cls_token``

        - ``mask_token``: (`Optional`) string: a masking token (e.g. when training a model with masked-language modeling). Will be associated to ``self.mask_token``

        - ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens. Adding all special tokens here ensures they won't be split by the tokenization process. Will be associated to ``self.additional_special_tokens``

    """
    vocab_files_names = {}
    pretrained_vocab_files_map = {}
    max_model_input_sizes = {}

    # Public names of the special-token attributes. Each is a property backed by a private
    # ``_<name>`` attribute; all hold a single string except ``additional_special_tokens`` (a list).
    SPECIAL_TOKENS_ATTRIBUTES = ["bos_token", "eos_token", "unk_token", "sep_token",
                                 "pad_token", "cls_token", "mask_token",
                                 "additional_special_tokens"]

    @property
    def bos_token(self):
        """ Beginning of sentence token (string). Log an error if used while not having been set. """
        if self._bos_token is None:
            logger.error("Using bos_token, but it is not set yet.")
        return self._bos_token

    @property
    def eos_token(self):
        """ End of sentence token (string). Log an error if used while not having been set. """
        if self._eos_token is None:
            logger.error("Using eos_token, but it is not set yet.")
        return self._eos_token

    @property
    def unk_token(self):
        """ Unknown token (string). Log an error if used while not having been set. """
        if self._unk_token is None:
            logger.error("Using unk_token, but it is not set yet.")
        return self._unk_token

    @property
    def sep_token(self):
        """ Separation token (string). E.g. separate context and query in an input sequence. Log an error if used while not having been set. """
        if self._sep_token is None:
            logger.error("Using sep_token, but it is not set yet.")
        return self._sep_token

    @property
    def pad_token(self):
        """ Padding token (string). Log an error if used while not having been set. """
        if self._pad_token is None:
            logger.error("Using pad_token, but it is not set yet.")
        return self._pad_token

    @property
    def cls_token(self):
        """ Classification token (string). E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """
        if self._cls_token is None:
            logger.error("Using cls_token, but it is not set yet.")
        return self._cls_token

    @property
    def mask_token(self):
        """ Mask token (string). E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """
        if self._mask_token is None:
            logger.error("Using mask_token, but it is not set yet.")
        return self._mask_token

    @property
    def additional_special_tokens(self):
        """ All the additional special tokens you may want to use (list of strings). Log an error if used while not having been set. """
        if self._additional_special_tokens is None:
            logger.error("Using additional_special_tokens, but it is not set yet.")
        return self._additional_special_tokens

    @bos_token.setter
    def bos_token(self, value):
        self._bos_token = value

    @eos_token.setter
    def eos_token(self, value):
        self._eos_token = value

    @unk_token.setter
    def unk_token(self, value):
        self._unk_token = value

    @sep_token.setter
    def sep_token(self, value):
        self._sep_token = value

    @pad_token.setter
    def pad_token(self, value):
        self._pad_token = value

    @cls_token.setter
    def cls_token(self, value):
        self._cls_token = value

    @mask_token.setter
    def mask_token(self, value):
        self._mask_token = value

    @additional_special_tokens.setter
    def additional_special_tokens(self, value):
        self._additional_special_tokens = value

    def __init__(self, max_len=None, **kwargs):
        # Private storage behind the special-token properties; unset single tokens stay ``None``.
        self._bos_token = None
        self._eos_token = None
        self._unk_token = None
        self._sep_token = None
        self._pad_token = None
        self._cls_token = None
        self._mask_token = None
        self._additional_special_tokens = []

        # ``int(1e12)`` stands in for "no practical limit" when no maximum length is given.
        self.max_len = max_len if max_len is not None else int(1e12)
        # Tokens added on top of the base vocabulary: token -> id and id -> token.
        self.added_tokens_encoder = {}
        self.added_tokens_decoder = {}

        # Only special-token keyword arguments are consumed by this base initializer;
        # any other keyword argument is ignored here.
        for key, value in kwargs.items():
            if key in self.SPECIAL_TOKENS_ATTRIBUTES:
                if key == 'additional_special_tokens':
                    assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value)
                else:
                    assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
                setattr(self, key, value)

    @classmethod
    def from_pretrained(cls, *inputs, **kwargs):
        r""" Instantiate a :class:`~pytorch_transformers.PreTrainedTokenizer` (or a derived class) from a predefined tokenizer.

        Parameters:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``.
                - a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~pytorch_transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``.
                - (not applicable to all derived classes) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used.

            inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method.

            kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~pytorch_transformers.PreTrainedTokenizer` for details.

        Returns:
            The tokenizer instance, or ``None`` (after logging an error) if the tokenizer files could not be found or fetched.

        Examples::

            # We can't instantiate directly the base class `PreTrainedTokenizer` so let's show our examples on a derived class: BertTokenizer

            # Download vocabulary from S3 and cache.
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

            # If vocabulary files are in a directory (e.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`)
            tokenizer = BertTokenizer.from_pretrained('./test/saved_model/')

            # If the tokenizer uses a single vocabulary file, you can point directly to this file
            tokenizer = BertTokenizer.from_pretrained('./test/saved_model/my_vocab.txt')

            # You can link tokens to special vocabulary when instantiating
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', unk_token='<unk>')
            # You should be sure '<unk>' is in the vocabulary when doing that.
            # Otherwise use tokenizer.add_special_tokens({'unk_token': '<unk>'}) instead)
            assert tokenizer.unk_token == '<unk>'

        """
        return cls._from_pretrained(*inputs, **kwargs)

    @classmethod
    def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        """ Locate, fetch (through ``cached_path``) and load the tokenizer files, then instantiate ``cls``.

        Resolved vocabulary file paths are passed to ``cls.__init__`` under their ``vocab_files_names``
        keyword. Special tokens from ``special_tokens_map.json`` are used unless overridden in ``kwargs``.
        Tokens from ``added_tokens.json`` are registered on the new instance.

        Returns:
            The tokenizer instance, or ``None`` (after logging an error) when no file could be found or fetched.
        """
        cache_dir = kwargs.pop('cache_dir', None)

        s3_models = list(cls.max_model_input_sizes.keys())
        vocab_files = {}
        if pretrained_model_name_or_path in s3_models:
            # Get the vocabulary from AWS S3 bucket
            for file_id, map_list in cls.pretrained_vocab_files_map.items():
                vocab_files[file_id] = map_list[pretrained_model_name_or_path]
        else:
            # Get the vocabulary from local files
            logger.info(
                "Model name '{}' not found in model shortcut name list ({}). "
                "Assuming '{}' is a path or url to a directory containing tokenizer files.".format(
                    pretrained_model_name_or_path, ', '.join(s3_models),
                    pretrained_model_name_or_path))

            # Look for the tokenizer main vocabulary files
            for file_id, file_name in cls.vocab_files_names.items():
                if os.path.isdir(pretrained_model_name_or_path):
                    # If a directory is provided we look for the standard filenames
                    full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
                else:
                    # If a path to a file is provided we use it (will only work for non-BPE tokenizer using a single vocabulary file)
                    full_file_name = pretrained_model_name_or_path
                if not os.path.exists(full_file_name):
                    logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
                    full_file_name = None
                vocab_files[file_id] = full_file_name

            # Look for the additional tokens files
            all_vocab_files_names = {'added_tokens_file': ADDED_TOKENS_FILE,
                                     'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE}

            # If a path to a file was provided, get the parent directory
            saved_directory = pretrained_model_name_or_path
            if os.path.exists(saved_directory) and not os.path.isdir(saved_directory):
                saved_directory = os.path.dirname(saved_directory)

            for file_id, file_name in all_vocab_files_names.items():
                full_file_name = os.path.join(saved_directory, file_name)
                if not os.path.exists(full_file_name):
                    logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
                    full_file_name = None
                vocab_files[file_id] = full_file_name

            if all(full_file_name is None for full_file_name in vocab_files.values()):
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find tokenizer files "
                    "at this path or url.".format(
                        pretrained_model_name_or_path, ', '.join(s3_models),
                        pretrained_model_name_or_path, ))
                return None

        # Get files from url, cache, or disk depending on the case
        try:
            resolved_vocab_files = {}
            for file_id, file_path in vocab_files.items():
                if file_path is None:
                    resolved_vocab_files[file_id] = None
                else:
                    resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir)
        except EnvironmentError:
            if pretrained_model_name_or_path in s3_models:
                logger.error("Couldn't reach server to download vocabulary.")
            else:
                logger.error(
                    "Model name '{}' was not found in model name list ({}). "
                    "We assumed '{}' was a path or url but couldn't find files {} "
                    "at this path or url.".format(
                        pretrained_model_name_or_path, ', '.join(s3_models),
                        pretrained_model_name_or_path, str(vocab_files.keys())))
            return None

        for file_id, file_path in vocab_files.items():
            if file_path is None:
                # Nothing was looked up for this file, so there is nothing to report.
                continue
            if file_path == resolved_vocab_files[file_id]:
                logger.info("loading file {}".format(file_path))
            else:
                logger.info("loading file {} from cache at {}".format(
                    file_path, resolved_vocab_files[file_id]))

        # Set max length if needed
        if pretrained_model_name_or_path in cls.max_model_input_sizes:
            # if we're using a pretrained model, ensure the tokenizer
            # wont index sequences longer than the number of positional embeddings
            max_len = cls.max_model_input_sizes[pretrained_model_name_or_path]
            if max_len is not None and isinstance(max_len, (int, float)):
                # An explicit ``max_len=None`` means "no user limit"; ``min(None, ...)`` would fail.
                requested_max_len = kwargs.get('max_len')
                if requested_max_len is None:
                    requested_max_len = int(1e12)
                kwargs['max_len'] = min(requested_max_len, max_len)

        # Merge resolved_vocab_files arguments in kwargs.
        added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None)
        special_tokens_map_file = resolved_vocab_files.pop('special_tokens_map_file', None)
        for args_name, file_path in resolved_vocab_files.items():
            if args_name not in kwargs:
                kwargs[args_name] = file_path
        if special_tokens_map_file is not None:
            with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
                special_tokens_map = json.load(special_tokens_map_handle)
            for key, value in special_tokens_map.items():
                if key not in kwargs:
                    kwargs[key] = value

        # Instantiate tokenizer.
        tokenizer = cls(*inputs, **kwargs)

        # Add supplementary tokens.
        if added_tokens_file is not None:
            with open(added_tokens_file, encoding="utf-8") as added_tokens_handle:
                added_tok_encoder = json.load(added_tokens_handle)
            added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
            tokenizer.added_tokens_encoder.update(added_tok_encoder)
            tokenizer.added_tokens_decoder.update(added_tok_decoder)

        return tokenizer

    def save_pretrained(self, save_directory):
        """ Save the tokenizer vocabulary files (with added tokens) and the
            special-tokens-to-class-attributes-mapping to a directory.

            This method makes sure the full tokenizer can then be re-loaded using the :func:`~pytorch_transformers.PreTrainedTokenizer.from_pretrained` class method.

            Returns:
                The value returned by ``save_vocabulary`` (expected to be a tuple of file paths) extended with the
                paths of the special tokens map file and of the added tokens file, or ``None`` (after logging an
                error) if ``save_directory`` is not an existing directory.
        """
        if not os.path.isdir(save_directory):
            logger.error("Saving directory ({}) should be a directory".format(save_directory))
            return

        special_tokens_map_file = os.path.join(save_directory, SPECIAL_TOKENS_MAP_FILE)
        added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE)

        with open(special_tokens_map_file, 'w', encoding='utf-8') as f:
            f.write(json.dumps(self.special_tokens_map, ensure_ascii=False))

        with open(added_tokens_file, 'w', encoding='utf-8') as f:
            if self.added_tokens_encoder:
                out_str = json.dumps(self.added_tokens_encoder, ensure_ascii=False)
            else:
                out_str = u"{}"
            f.write(out_str)

        vocab_files = self.save_vocabulary(save_directory)

        return vocab_files + (special_tokens_map_file, added_tokens_file)

    def save_vocabulary(self, save_directory):
        """ Save the tokenizer vocabulary to a directory. This method does *NOT* save added tokens
            and special token mappings.

            Please use :func:`~pytorch_transformers.PreTrainedTokenizer.save_pretrained` `()` to save the full Tokenizer state if you want to reload it using the :func:`~pytorch_transformers.PreTrainedTokenizer.from_pretrained` class method.

            Must be implemented by derived classes; ``save_pretrained`` concatenates its return value with a tuple.
        """
        raise NotImplementedError

    @property
    def vocab_size(self):
        """ Size of the base vocabulary (without the added tokens).

        A property (``__len__`` reads ``self.vocab_size`` as a value); derived classes must override it.
        """
        raise NotImplementedError

    def __len__(self):
        """ Size of the full vocabulary with the added tokens """
        return self.vocab_size + len(self.added_tokens_encoder)

    def add_tokens(self, new_tokens):
        """ Add a list of new tokens to the tokenizer class. If the new tokens are not in the
            vocabulary, they are added to it with indices starting from length of the current vocabulary.

            Parameters:
                new_tokens: list of string. Each string is a token to add. Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them). A token repeated in ``new_tokens`` is added only once.

            Returns:
                Number of tokens added to the vocabulary.

        Examples::

            # Let's see how to increase the vocabulary of Bert model and tokenizer
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            model = BertModel.from_pretrained('bert-base-uncased')

            num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
            print('We have added', num_added_toks, 'tokens')
            model.resize_token_embeddings(len(tokenizer))  # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
        """
        if not new_tokens:
            return 0

        # Loop invariants: a candidate is "unknown" when it maps to the same id as the unk token.
        unk_token = self.unk_token
        unk_token_id = self.convert_tokens_to_ids(unk_token)

        to_add_tokens = []
        seen = set()
        for token in new_tokens:
            assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode))
            if token in seen:
                # Adding a duplicate would give it a second id and leave a gap in the id range.
                continue
            if token != unk_token and self.convert_tokens_to_ids(token) == unk_token_id:
                seen.add(token)
                to_add_tokens.append(token)
                logger.info("Adding %s to the vocabulary", token)

        first_new_id = len(self)
        added_tok_encoder = dict((tok, first_new_id + i) for i, tok in enumerate(to_add_tokens))
        added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
        self.added_tokens_encoder.update(added_tok_encoder)
        self.added_tokens_decoder.update(added_tok_decoder)

        return len(to_add_tokens)

    def add_special_tokens(self, special_tokens_dict):
        """ Add a dictionary of special tokens (eos, pad, cls...) to the encoder and link them
            to class attributes. If special tokens are NOT in the vocabulary, they are added
            to it (indexed starting from the last index of the current vocabulary).

            Parameters:
                special_tokens_dict: dict of string. Keys should be in the list of predefined special attributes: [``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``].

                    Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them).

            Returns:
                Number of tokens added to the vocabulary.

        Examples::

            # Let's see how to add a new classification token to GPT-2
            tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            model = GPT2Model.from_pretrained('gpt2')

            special_tokens_dict = {'cls_token': '<CLS>'}

            num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
            print('We have added', num_added_toks, 'tokens')
            model.resize_token_embeddings(len(tokenizer))  # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.

            assert tokenizer.cls_token == '<CLS>'
        """
        if not special_tokens_dict:
            return 0

        added_tokens = 0
        for key, value in special_tokens_dict.items():
            assert key in self.SPECIAL_TOKENS_ATTRIBUTES
            if key == 'additional_special_tokens':
                assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value)
                added_tokens += self.add_tokens(value)
            else:
                assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
                added_tokens += self.add_tokens([value])
            logger.info("Assigning %s to the %s key of the tokenizer", value, key)
            setattr(self, key, value)

        return added_tokens

    def tokenize(self, text, **kwargs):
        """ Converts a string in a sequence of tokens (string), using the tokenizer.
            Split in words for word-based vocabulary or sub-words for sub-word-based
            vocabularies (BPE/SentencePieces/WordPieces).

            Take care of added tokens. The text is first split on every added token and every special
            token, which are kept as single tokens. The remaining pieces are stripped of surrounding
            whitespace and passed to ``_tokenize`` (together with ``kwargs``).
        """
        def split_on_tokens(tok_list, text):
            # Split ``text`` on ``tok_list[0]``, recurse on each piece with the remaining tokens,
            # and re-insert ``tok_list[0]`` between the pieces.
            if not text:
                return []
            if not tok_list:
                return self._tokenize(text, **kwargs)
            tok = tok_list[0]
            pieces = []
            for sub_text in text.split(tok):
                pieces.extend(split_on_tokens(tok_list[1:], sub_text.strip()))
                pieces.append(tok)
            # Drop the separator appended after the last piece.
            return pieces[:-1]

        # Deduplicate while preserving order. Skip empty strings, on which ``str.split`` raises ValueError.
        added_tokens = []
        for tok in list(self.added_tokens_encoder.keys()) + self.all_special_tokens:
            if tok and tok not in added_tokens:
                added_tokens.append(tok)

        tokenized_text = split_on_tokens(added_tokens, text)
        return tokenized_text

    def _tokenize(self, text, **kwargs):
        """ Converts a string in a sequence of tokens (string), using the tokenizer.
            Split in words for word-based vocabulary or sub-words for sub-word-based
            vocabularies (BPE/SentencePieces/WordPieces).

            Do NOT take care of added tokens.
        """
        raise NotImplementedError

    def convert_tokens_to_ids(self, tokens):
        """ Converts a single token, or a sequence of tokens, (str/unicode) in a single integer id
            (resp. a sequence of ids), using the vocabulary and the added tokens.

            Returns ``None`` when ``tokens`` is ``None``. Logs a warning when a sequence converts
            to more than ``self.max_len`` ids.
        """
        if tokens is None:
            return None

        if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)):
            return self._convert_token_to_id_with_added_voc(tokens)

        ids = [self._convert_token_to_id_with_added_voc(token) for token in tokens]
        if len(ids) > self.max_len:
            logger.warning("Token indices sequence length is longer than the specified maximum sequence length "
                           "for this model ({} > {}). Running this sequence through the model will result in "
                           "indexing errors".format(len(ids), self.max_len))
        return ids

    def _convert_token_to_id_with_added_voc(self, token):
        """ Id of ``token``: its added-token id if it was added, else the base vocabulary id. """
        if token in self.added_tokens_encoder:
            return self.added_tokens_encoder[token]
        return self._convert_token_to_id(token)

    def _convert_token_to_id(self, token):
        """ Id of ``token`` in the base vocabulary. Must be implemented by derived classes. """
        raise NotImplementedError

    def encode(self, text, add_special_tokens=False, *sequences):
        """ Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.

        Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``.

        If a second sequence is given in ``sequences``, both are encoded. The result is
        ``add_special_tokens_sentences_pair(ids_a, ids_b)`` when ``add_special_tokens`` is true,
        otherwise the tuple ``(ids_a, ids_b)``. Sequences beyond the second are ignored, with a warning.

        Note: since ``*sequences`` comes after ``add_special_tokens``, a second sequence must be passed
        after an explicit ``add_special_tokens`` value, e.g. ``encode(text_a, False, text_b)``.
        """
        if len(sequences) == 0:
            token_ids = self.convert_tokens_to_ids(self.tokenize(text))
            if add_special_tokens:
                return self.add_special_tokens_single_sentence(token_ids)
            return token_ids

        if len(sequences) > 1:
            logger.warning("Tokenization currently only supports sentence pairs. Ignoring every string following the "
                           "initial two.")

        # Go through ``convert_tokens_to_ids`` (not ``_convert_token_to_id``) so that added tokens
        # get their added ids in sentence pairs too, exactly as in the single-sentence case.
        first_sentence_tokens = self.convert_tokens_to_ids(self.tokenize(text))
        second_sentence_tokens = self.convert_tokens_to_ids(self.tokenize(sequences[0]))

        if add_special_tokens:
            return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens)
        return first_sentence_tokens, second_sentence_tokens

    def add_special_tokens_single_sentence(self, token_ids):
        """ Wrap the ids of a single sentence with model-specific special tokens. Implemented by derived classes. """
        raise NotImplementedError

    def add_special_tokens_sentences_pair(self, *token_ids):
        """ Join the ids of a sentence pair with model-specific special tokens. Implemented by derived classes. """
        raise NotImplementedError

    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        """ Converts a single index or a sequence of indices (integers) in a token
            (resp.) a sequence of tokens (str/unicode), using the vocabulary and added tokens.

            Args:
                skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False.
                    Only applies when ``ids`` is a sequence.
        """
        if isinstance(ids, int):
            if ids in self.added_tokens_decoder:
                return self.added_tokens_decoder[ids]
            return self._convert_id_to_token(ids)

        # Computed once instead of once per index (``all_special_ids`` rebuilds its list on every access).
        special_ids = self.all_special_ids if skip_special_tokens else []
        tokens = []
        for index in ids:
            if skip_special_tokens and index in special_ids:
                continue
            if index in self.added_tokens_decoder:
                tokens.append(self.added_tokens_decoder[index])
            else:
                tokens.append(self._convert_id_to_token(index))
        return tokens

    def _convert_id_to_token(self, index):
        """ Token of ``index`` in the base vocabulary. Must be implemented by derived classes. """
        raise NotImplementedError

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) in a single string.
            The most simple way to do it is ' '.join(tokens)
            but we often want to remove sub-word tokenization artifacts at the same time.
        """
        return ' '.join(tokens)

    def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
        """ Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
            with options to remove special tokens and clean up tokenization spaces.

        Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.

        If a separation token is set and appears in the decoded text, the classification token (when set) is
        replaced by it. The text is then split on it, and a list of the non-empty pieces is returned instead
        of a single string.
        """
        filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
        text = self.convert_tokens_to_string(filtered_tokens)

        # Read the private attributes: the public properties log an error whenever the token is unset,
        # which is an expected situation here.
        sep_token = self._sep_token
        cls_token = self._cls_token
        if sep_token is not None and sep_token in text:
            if cls_token:
                text = text.replace(cls_token, sep_token)
            split_text = [sentence for sentence in text.split(sep_token) if len(sentence) > 0]
            if clean_up_tokenization_spaces:
                return [self.clean_up_tokenization(sentence) for sentence in split_text]
            return split_text

        if clean_up_tokenization_spaces:
            return self.clean_up_tokenization(text)
        return text

    @property
    def special_tokens_map(self):
        """ A dictionary mapping special token class attribute (cls_token, unk_token...) to their
            values ('<unk>', '<cls>'...). Attributes whose value is unset or empty are omitted.
        """
        set_attr = {}
        for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
            attr_value = getattr(self, "_" + attr)
            if attr_value:
                set_attr[attr] = attr_value
        return set_attr

    @property
    def all_special_tokens(self):
        """ List all the special tokens ('<unk>', '<cls>'...) mapped to class attributes
            (cls_token, unk_token...), without duplicates, in ``SPECIAL_TOKENS_ATTRIBUTES`` order.
        """
        set_attr = self.special_tokens_map
        all_toks = []
        # Iterate in a fixed attribute order (not over a set) so the result, and therefore the
        # split order used by ``tokenize``, is the same on every run.
        for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
            if attr not in set_attr:
                continue
            attr_value = set_attr[attr]
            for tok in (attr_value if isinstance(attr_value, (list, tuple)) else [attr_value]):
                if tok not in all_toks:
                    all_toks.append(tok)
        return all_toks

    @property
    def all_special_ids(self):
        """ List the vocabulary indices of the special tokens ('<unk>', '<cls>'...) mapped to
            class attributes (cls_token, unk_token...), including special tokens that were added
            to the vocabulary with ``add_special_tokens``.
        """
        return [self._convert_token_to_id_with_added_voc(t) for t in self.all_special_tokens]

    @staticmethod
    def clean_up_tokenization(out_string):
        """ Clean up a list of simple English tokenization artifacts like spaces before punctuations and abbreviated forms.
        """
        out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
                        ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
                        ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
        return out_string