# coding=utf-8
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

from .tokenization_roberta import RobertaTokenizer
from .tokenization_utils_base import BatchEncoding
from .utils import logging


logger = logging.get_logger(__name__)


# vocab and merges same as roberta
vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json"
merges_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt"
_all_bart_models = [
    "facebook/bart-base",
    "facebook/bart-large",
    "facebook/bart-large-mnli",
    "facebook/bart-large-cnn",
    "facebook/bart-large-xsum",
    "yjernite/bart_eli5",
    # This is not exhaustive: see https://huggingface.co/models?filter=bart
]


class BartTokenizer(RobertaTokenizer):
    r"""
    Construct a BART tokenizer.

    :class:`~transformers.BartTokenizer` is identical to :class:`~transformers.RobertaTokenizer` and adds a new
    :meth:`~transformers.BartTokenizer.prepare_seq2seq_batch`.

    Refer to superclass :class:`~transformers.RobertaTokenizer` for usage examples and documentation concerning
    the initialization parameters and other methods.
    """
    # merges and vocab same as Roberta
    max_model_input_sizes = {m: 1024 for m in _all_bart_models}
    pretrained_vocab_files_map = {
        "vocab_file": {m: vocab_url for m in _all_bart_models},
        "merges_file": {m: merges_url for m in _all_bart_models},
    }

    def prepare_seq2seq_batch(
        self,
        src_texts: List[str],
        tgt_texts: Optional[List[str]] = None,
        max_length: Optional[int] = None,
        max_target_length: Optional[int] = None,
        padding: str = "longest",
        return_tensors: Optional[str] = None,
        truncation=True,
        **kwargs,
    ) -> BatchEncoding:
        r"""

        Prepare a batch that can be passed directly to an instance of :class:`~transformers.BartModel`.

        Args:
            src_texts (:obj:`List[str]`):
                List of documents to summarize or source language texts.
            tgt_texts: (:obj:`List[str]`, `optional`):
                List of summaries or target language texts.
            max_length (:obj:`int`, `optional`):
                Controls the maximum length for encoder inputs (documents to summarize or source language texts).
                If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum
                length is required by one of the truncation/padding parameters. If the model has no specific maximum
                input length (like XLNet) truncation/padding to a maximum length will be deactivated.
            max_target_length (:obj:`int`, `optional`):
                Controls the maximum length of decoder inputs (target language texts or summaries).
                If left unset or set to :obj:`None`, this will use the max_length value.
            padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`'longest'`):
                Activates and controls padding. Accepts the following values:

                * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
                  single sequence is provided).
                * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
                  maximum acceptable input length for the model if that argument is not provided.
                * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
                  different lengths).
            return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
                If set, will return tensors instead of list of python integers. Acceptable values are:

                * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
                * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
                * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
            truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
                Activates and controls truncation. Accepts the following values:

                * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
                  :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
                  provided. This will truncate token by token, removing a token from the longest sequence in the pair
                  if a pair of sequences (or a batch of pairs) is provided.
                * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
                  the maximum acceptable input length for the model if that argument is not provided. This will only
                  truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
                * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
                  to the maximum acceptable input length for the model if that argument is not provided. This will only
                  truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
                * :obj:`False` or :obj:`'do_not_truncate'`: No truncation (i.e., can output batch with
                  sequence lengths greater than the model maximum admissible input size).
            **kwargs:
                Additional keyword arguments passed along to :obj:`self.__call__`.

        Returns:
            :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:

            - **input_ids** -- List of token ids to be fed to the encoder.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
            - **labels** -- List of token ids for :obj:`tgt_texts`.

            The full set of keys ``[input_ids, attention_mask, labels]`` will only be returned if :obj:`tgt_texts` is
            passed. Otherwise, input_ids and attention_mask will be the only keys.
        """
        kwargs.pop("src_lang", None)
        kwargs.pop("tgt_lang", None)
        if max_length is None:
            max_length = self.model_max_length
        model_inputs: BatchEncoding = self(
            src_texts,
            add_special_tokens=True,
            return_tensors=return_tensors,
            max_length=max_length,
            padding=padding,
            truncation=truncation,
            **kwargs,
        )
        if tgt_texts is None:
            return model_inputs
        # Process tgt_texts
        if max_target_length is None:
            max_target_length = max_length
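        # Tokenize the target texts; their input_ids are returned under the `labels` key.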
        labels = self(
            tgt_texts,
            add_special_tokens=True,
            return_tensors=return_tensors,
            padding=padding,
            max_length=max_target_length,
            truncation=truncation,
            **kwargs,
        )["input_ids"]
        model_inputs["labels"] = labels
        return model_inputs
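

# Usage sketch (illustrative only): it assumes the `transformers` package that ships this module
# is installed and that the "facebook/bart-large-cnn" checkpoint (one of `_all_bart_models` above)
# is reachable. Kept as a comment because this file uses package-relative imports and is not meant
# to be executed directly.
#
#     from transformers import BartTokenizer
#
#     tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
#     batch = tokenizer.prepare_seq2seq_batch(
#         src_texts=["A long article to summarize."],
#         tgt_texts=["A short summary."],
#         return_tensors="pt",
#     )
#     # `batch` contains `input_ids` and `attention_mask` for the source texts and, because
#     # `tgt_texts` was given, `labels` holding the tokenized summaries.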