# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import re
import warnings
from contextlib import contextmanager
from pathlib import Path
from shutil import copyfile
from typing import Dict, List, Optional, Tuple, Union

import sentencepiece

from ...tokenization_utils import PreTrainedTokenizer


VOCAB_FILES_NAMES = {
    "source_spm": "source.spm",
    "target_spm": "target.spm",
    "vocab": "vocab.json",
    "tokenizer_config_file": "tokenizer_config.json",
}

PRETRAINED_VOCAB_FILES_MAP = {
    "source_spm": {"Helsinki-NLP/opus-mt-en-de": "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/source.spm"},
    "target_spm": {"Helsinki-NLP/opus-mt-en-de": "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/target.spm"},
    "vocab": {"Helsinki-NLP/opus-mt-en-de": "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/vocab.json"},
    "tokenizer_config_file": {
        "Helsinki-NLP/opus-mt-en-de": "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/tokenizer_config.json"
    },
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"Helsinki-NLP/opus-mt-en-de": 512}
PRETRAINED_INIT_CONFIGURATION = {}

# Example URL https://huggingface.co/Helsinki-NLP/opus-mt-en-de/resolve/main/vocab.json


class MarianTokenizer(PreTrainedTokenizer):
    r"""
    Construct a Marian tokenizer. Based on `SentencePiece <https://github.com/google/sentencepiece>`__.

    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
    Users should refer to this superclass for more information regarding those methods.

    Args:
        vocab (:obj:`str`):
            Path to a json file containing the vocabulary (a token to id mapping) used for both source and
            target text.
        source_spm (:obj:`str`):
            `SentencePiece <https://github.com/google/sentencepiece>`__ file (generally has a .spm extension) that
            contains the vocabulary for the source language.
        target_spm (:obj:`str`):
            `SentencePiece <https://github.com/google/sentencepiece>`__ file (generally has a .spm extension) that
            contains the vocabulary for the target language.
        source_lang (:obj:`str`, `optional`):
            A string representing the source language.
        target_lang (:obj:`str`, `optional`):
            A string representing the target language.
        unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        eos_token (:obj:`str`, `optional`, defaults to :obj:`"</s>"`):
            The end of sequence token.
        pad_token (:obj:`str`, `optional`, defaults to :obj:`"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        model_max_length (:obj:`int`, `optional`, defaults to 512):
            The maximum sentence length the model accepts.
        additional_special_tokens (:obj:`List[str]`, `optional`, defaults to :obj:`["<eop>", "<eod>"]`):
            Additional special tokens used by the tokenizer.

    Examples::

        >>> from transformers import MarianTokenizer
        >>> tok = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
        >>> src_texts = [ "I am a small frog.", "Tom asked his teacher for advice."]
        >>> tgt_texts = ["Ich bin ein kleiner Frosch.", "Tom bat seinen Lehrer um Rat."]  # optional
        >>> batch_enc = tok.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts, return_tensors="pt")
        >>> # keys  [input_ids, attention_mask, labels].
        >>> # model(**batch) should work
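        >>> # Labels can also be built explicitly with the target-side SentencePiece model
        >>> # (a minimal sketch using the `as_target_tokenizer` context manager defined below):
        >>> with tok.as_target_tokenizer():
        ...     labels = tok(tgt_texts, return_tensors="pt").input_ids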
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids", "attention_mask"]
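    # Matches target-language codes such as ">>fr<<" that may prefix the source text of multilingual
    # models; `_tokenize` splits them off so they stay single tokens and are never passed to SentencePiece.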
    language_code_re = re.compile(">>.+<<")  # type: re.Pattern

    def __init__(
        self,
        vocab,
        source_spm,
        target_spm,
        source_lang=None,
        target_lang=None,
        unk_token="<unk>",
        eos_token="</s>",
        pad_token="<pad>",
        model_max_length=512,
        **kwargs
    ):
        super().__init__(
            # bos_token=bos_token,  unused. Start decoding with config.decoder_start_token_id
            source_lang=source_lang,
            target_lang=target_lang,
            unk_token=unk_token,
            eos_token=eos_token,
            pad_token=pad_token,
            model_max_length=model_max_length,
            **kwargs,
        )
        assert Path(source_spm).exists(), f"cannot find spm source {source_spm}"
        self.encoder = load_json(vocab)
        if self.unk_token not in self.encoder:
            raise KeyError("<unk> token must be in vocab")
        assert self.pad_token in self.encoder
        self.decoder = {v: k for k, v in self.encoder.items()}

        self.source_lang = source_lang
        self.target_lang = target_lang
        self.supported_language_codes: list = [k for k in self.encoder if k.startswith(">>") and k.endswith("<<")]
        self.spm_files = [source_spm, target_spm]

        # load SentencePiece model for pre-processing
        self.spm_source = load_spm(source_spm)
        self.spm_target = load_spm(target_spm)
        self.current_spm = self.spm_source

        self._setup_normalizer()

    def _setup_normalizer(self):
        try:
            from sacremoses import MosesPunctNormalizer

            self.punc_normalizer = MosesPunctNormalizer(self.source_lang).normalize
        except (ImportError, FileNotFoundError):
            warnings.warn("Recommended: pip install sacremoses.")
            self.punc_normalizer = lambda x: x

    def normalize(self, x: str) -> str:
        """Cover the Moses empty-string edge case: sacremoses returns an empty list for '' input."""
        return self.punc_normalizer(x) if x else ""

    def _convert_token_to_id(self, token):
        return self.encoder.get(token, self.encoder[self.unk_token])

    def remove_language_code(self, text: str):
        """Remove language codes like <<fr>> before sentencepiece"""
        match = self.language_code_re.match(text)
        code: list = [match.group(0)] if match else []
        return code, self.language_code_re.sub("", text)

    def _tokenize(self, text: str) -> List[str]:
        code, text = self.remove_language_code(text)
        pieces = self.current_spm.EncodeAsPieces(text)
        return code + pieces

    def _convert_id_to_token(self, index: int) -> str:
        """Converts an index (integer) in a token (str) using the encoder."""
        return self.decoder.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Uses target language sentencepiece model"""
        return self.spm_target.DecodePieces(tokens)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
        """Build model inputs from a sequence by appending eos_token_id."""
        if token_ids_1 is None:
            return token_ids_0 + [self.eos_token_id]
        # We don't expect to process pairs, but leave the pair logic for API consistency
        return token_ids_0 + token_ids_1 + [self.eos_token_id]

    @contextmanager
    def as_target_tokenizer(self):
        """
        Temporarily sets the tokenizer for encoding the targets. Useful for tokenizers associated with
        sequence-to-sequence models that need a slightly different processing for the labels.
        """
        self.current_spm = self.spm_target
        yield
        self.current_spm = self.spm_source

    @property
    def vocab_size(self) -> int:
        return len(self.encoder)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        save_dir = Path(save_directory)
        assert save_dir.is_dir(), f"{save_directory} should be a directory"
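        # Write vocab.json (with the optional filename prefix) and copy the two SentencePiece model files.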
        save_json(
            self.encoder,
            save_dir / ((filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab"]),
        )

        for orig, f in zip(["source.spm", "target.spm"], self.spm_files):
            dest_path = save_dir / ((filename_prefix + "-" if filename_prefix else "") + orig)
            if not dest_path.exists():
                copyfile(f, dest_path)

        return tuple(
            save_dir / ((filename_prefix + "-" if filename_prefix else "") + f)
            for f in self.vocab_files_names.values()
        )

    def get_vocab(self) -> Dict:
        vocab = self.encoder.copy()
        vocab.update(self.added_tokens_encoder)
        return vocab

    def __getstate__(self) -> Dict:
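        # Drop the SentencePiece processors and the sacremoses normalizer; they are reloaded in __setstate__.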
        state = self.__dict__.copy()
        state.update({k: None for k in ["spm_source", "spm_target", "current_spm", "punc_normalizer"]})
        return state

    def __setstate__(self, d: Dict) -> None:
        self.__dict__ = d
        self.spm_source, self.spm_target = (load_spm(f) for f in self.spm_files)
        self.current_spm = self.spm_source
        self._setup_normalizer()

    def num_special_tokens_to_add(self, **unused):
        """Just EOS"""
        return 1

    def _special_token_mask(self, seq):
        all_special_ids = set(self.all_special_ids)  # call it once instead of inside list comp
        all_special_ids.remove(self.unk_token_id)  # <unk> is only sometimes special
        return [1 if x in all_special_ids else 0 for x in seq]

    def get_special_tokens_mask(
        self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """Get list where entries are [1] if a token is [eos] or [pad] else 0."""
        if already_has_special_tokens:
            return self._special_token_mask(token_ids_0)
        elif token_ids_1 is None:
            return self._special_token_mask(token_ids_0) + [1]
        else:
            return self._special_token_mask(token_ids_0 + token_ids_1) + [1]


def load_spm(path: str) -> sentencepiece.SentencePieceProcessor:
    spm = sentencepiece.SentencePieceProcessor()
    spm.Load(path)
    return spm


def save_json(data, path: str) -> None:
    with open(path, "w") as f:
        json.dump(data, f, indent=2)


def load_json(path: str) -> Union[Dict, List]:
    with open(path, "r") as f:
        return json.load(f)