# coding=utf-8
# Copyright 2021 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple

import sentencepiece as spm

from ...tokenization_utils import AddedToken, BatchEncoding, PreTrainedTokenizer
from ...utils import logging


logger = logging.get_logger(__name__)

SPIECE_UNDERLINE = "▁"

VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "facebook/mbart-large-50-one-to-many-mmt": (
            "https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt/resolve/main/sentencepiece.bpe.model"
        ),
    }
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "facebook/mbart-large-50-one-to-many-mmt": 1024,
}

# fmt: off
FAIRSEQ_LANGUAGE_CODES = ["ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI", "fr_XX", "gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR", "lt_LT", "lv_LV", "my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU", "si_LK", "tr_TR", "vi_VN", "zh_CN", "af_ZA", "az_AZ", "bn_IN", "fa_IR", "he_IL", "hr_HR", "id_ID", "ka_GE", "km_KH", "mk_MK", "ml_IN", "mn_MN", "mr_IN", "pl_PL", "ps_AF", "pt_XX", "sv_SE", "sw_KE", "ta_IN", "te_IN", "th_TH", "tl_XX", "uk_UA", "ur_PK", "xh_ZA", "gl_ES", "sl_SI"]
# fmt: on


class MBart50Tokenizer(PreTrainedTokenizer):
    """
    Construct a MBart50 tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        src_lang (`str`, *optional*):
            A string representing the source language.
        tgt_lang (`str`, *optional*):
            A string representing the target language.
        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.
        sep_token (`str`, *optional*, defaults to `"</s>"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        cls_token (`str`, *optional*, defaults to `"<s>"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        mask_token (`str`, *optional*, defaults to `"<mask>"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        sp_model_kwargs (`dict`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
            to set:

            - `enable_sampling`: Enable subword regularization.
            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

              - `nbest_size = {0,1}`: No sampling is performed.
              - `nbest_size > 1`: samples from the nbest_size results.
              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from all hypotheses (lattice)
                using the forward-filtering-and-backward-sampling algorithm.

            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
              BPE-dropout.

    Examples:

    ```python
    >>> from transformers import MBart50Tokenizer

    >>> tokenizer = MBart50Tokenizer.from_pretrained("facebook/mbart-large-50", src_lang="en_XX", tgt_lang="ro_RO")
    >>> src_text = " UN Chief Says There Is No Military Solution in Syria"
    >>> tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"
    >>> model_inputs = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")
    >>> # model(**model_inputs) should work
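    >>> # for translation, the target language id is typically passed to generate(), e.g.
    >>> # model.generate(**model_inputs, forced_bos_token_id=tokenizer.lang_code_to_id["ro_RO"])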
    ```"""

    vocab_files_names = VOCAB_FILES_NAMES
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    model_input_names = ["input_ids", "attention_mask"]

    prefix_tokens: List[int] = []
    suffix_tokens: List[int] = []
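    # prefix_tokens / suffix_tokens wrap every encoded sequence; they are reset to [lang_code_id] and
    # [eos_token_id] by set_src_lang_special_tokens / set_tgt_lang_special_tokens.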

    def __init__(
        self,
        vocab_file,
        src_lang=None,
        tgt_lang=None,
        eos_token="</s>",
        sep_token="</s>",
        cls_token="<s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> None:
        # Mask token behaves like a normal word, i.e. it includes the space before it
        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

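        # Register every FAIRSEQ language code as an additional special token so it is never split by SentencePiece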
        kwargs["additional_special_tokens"] = kwargs.get("additional_special_tokens", [])
        kwargs["additional_special_tokens"] += [
            code for code in FAIRSEQ_LANGUAGE_CODES if code not in kwargs["additional_special_tokens"]
        ]

        super().__init__(
            src_lang=src_lang,
            tgt_lang=tgt_lang,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            sp_model_kwargs=self.sp_model_kwargs,
            **kwargs,
        )

        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(str(vocab_file))
        self.vocab_file = vocab_file

        # Original fairseq vocab and spm vocab must be "aligned":
        # Vocab    |    0    |    1    |   2    |    3    |  4  |  5  |  6  |   7   |   8   |  9
        # -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
        # fairseq  | '<s>'   | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's'   | '▁de' | '-'
        # spm      | '<unk>' | '<s>'   | '</s>' | ','     | '.' | '▁' | 's' | '▁de' | '-'   | '▁a'

        # Mimic fairseq token-to-id alignment for the first 4 tokens
        self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}

        # The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
        self.fairseq_offset = 1

        self.sp_model_size = len(self.sp_model)
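        # Language codes are appended after the SentencePiece vocabulary (shifted by the fairseq offset),
        # mirroring the original fairseq ordering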
        self.lang_code_to_id = {
            code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES)
        }
        self.id_to_lang_code = {v: k for k, v in self.lang_code_to_id.items()}
        self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset

        self.fairseq_tokens_to_ids.update(self.lang_code_to_id)
        self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}

        self._src_lang = src_lang if src_lang is not None else "en_XX"
        self.cur_lang_code_id = self.lang_code_to_id[self._src_lang]
        self.tgt_lang = tgt_lang
        self.set_src_lang_special_tokens(self._src_lang)

    @property
    def vocab_size(self) -> int:
        return len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset + 1  # Plus 1 for the mask token

    @property
    def src_lang(self) -> str:
        return self._src_lang

    @src_lang.setter
    def src_lang(self, new_src_lang: str) -> None:
        self._src_lang = new_src_lang
        self.set_src_lang_special_tokens(self._src_lang)

    def __getstate__(self) -> Dict:
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    def __setstate__(self, d: Dict) -> None:
        self.__dict__ = d

        # for backward compatibility
        if not hasattr(self, "sp_model_kwargs"):
            self.sp_model_kwargs = {}

        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(self.vocab_file)

    def get_vocab(self) -> Dict:
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text: str) -> List[str]:
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token: str) -> int:
        """Converts a token (str) in an id using the vocab."""
        if token in self.fairseq_tokens_to_ids:
            return self.fairseq_tokens_to_ids[token]
        spm_id = self.sp_model.PieceToId(token)

        # Need to return unknown token if the SP model returned 0
        return spm_id + self.fairseq_offset if spm_id else self.unk_token_id

    def _convert_id_to_token(self, index: int) -> str:
        """Converts an index (integer) in a token (str) using the vocab."""
        if index in self.fairseq_ids_to_tokens:
            return self.fairseq_ids_to_tokens[index]
        return self.sp_model.IdToPiece(index - self.fairseq_offset)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        current_sub_tokens = []
        out_string = ""
        prev_is_special = False
        for token in tokens:
            # make sure that special tokens are not decoded using sentencepiece model
            if token in self.all_special_tokens:
                if not prev_is_special:
                    out_string += " "
                out_string += self.sp_model.decode(current_sub_tokens) + token
                prev_is_special = True
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
                prev_is_special = False
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string.strip()

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
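            # The original file is unavailable (e.g. the tokenizer was loaded from a serialized proto),
            # so write out the in-memory SentencePiece model instead of copying a file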
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        prefix_ones = [1] * len(self.prefix_tokens)
        suffix_ones = [1] * len(self.suffix_tokens)
        if token_ids_1 is None:
            return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
        return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
        adding special tokens. An MBART-50 sequence has the following format, where `X` represents the sequence:

        - `input_ids` (for encoder) `[src_lang_code] X [eos]`
        - `labels` (for decoder) `[tgt_lang_code] X [eos]`

        BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
        separator.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return self.prefix_tokens + token_ids_0 + self.suffix_tokens
        # We don't expect to process pairs, but leave the pair logic for API consistency
        return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens

    def _build_translation_inputs(
        self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
    ):
        """Used by translation pipeline, to prepare inputs for the generate function"""
        if src_lang is None or tgt_lang is None:
            raise ValueError("Translation requires a `src_lang` and a `tgt_lang` for this model")
        self.src_lang = src_lang
        inputs = self(raw_inputs, add_special_tokens=True, return_tensors=return_tensors, **extra_kwargs)
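        # generate() will force the target language code to be the first generated token via forced_bos_token_id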
        tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
        inputs["forced_bos_token_id"] = tgt_lang_id
        return inputs

    def prepare_seq2seq_batch(
        self,
        src_texts: List[str],
        src_lang: str = "en_XX",
        tgt_texts: Optional[List[str]] = None,
        tgt_lang: str = "ro_RO",
        **kwargs,
    ) -> BatchEncoding:
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)

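    # Hooks used by the base tokenizer when encoding inputs vs. targets (e.g. via `text_target`), so the
    # matching language code is used as the prefix token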
    def _switch_to_input_mode(self):
        return self.set_src_lang_special_tokens(self.src_lang)

    def _switch_to_target_mode(self):
        return self.set_tgt_lang_special_tokens(self.tgt_lang)

    def set_src_lang_special_tokens(self, src_lang: str) -> None:
        """Reset the special tokens to the source lang setting. prefix=[src_lang_code] and suffix=[eos]."""
        self.cur_lang_code_id = self.lang_code_to_id[src_lang]
        self.prefix_tokens = [self.cur_lang_code_id]
        self.suffix_tokens = [self.eos_token_id]

    def set_tgt_lang_special_tokens(self, tgt_lang: str) -> None:
        """Reset the special tokens to the target language setting. prefix=[tgt_lang_code] and suffix=[eos]."""
        self.cur_lang_code_id = self.lang_code_to_id[tgt_lang]
        self.prefix_tokens = [self.cur_lang_code_id]
        self.suffix_tokens = [self.eos_token_id]