Unverified Commit efbc1c5a authored by Sam Shleifer, committed by GitHub

[MarianTokenizer] implement save_vocabulary and other common methods (#4389)

parent 956c4c4e
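In practical terms, the new save_vocabulary (see the diff below) writes vocab.json and copies the SentencePiece model files into the target directory, which is what lets the standard save_pretrained / from_pretrained round trip work for Marian. A minimal sketch of that round trip, not part of the commit itself (the checkpoint name is the MODEL_NAMES entry used below; everything else is illustrative):

import tempfile

from transformers.tokenization_marian import MarianTokenizer

tok = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")

with tempfile.TemporaryDirectory() as tmp_dir:
    # save_pretrained calls the new save_vocabulary, which dumps vocab.json
    # and copies the source/target SentencePiece files into tmp_dir.
    tok.save_pretrained(tmp_dir)
    reloaded = MarianTokenizer.from_pretrained(tmp_dir)

    # The reloaded tokenizer should produce the same ids as the original.
    text = ["I am a small frog"]
    assert (
        tok.prepare_translation_batch(text, return_tensors=None).input_ids
        == reloaded.prepare_translation_batch(text, return_tensors=None).input_ids
    )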
src/transformers/tokenization_marian.py
import json
import re
import warnings
-from typing import Dict, List, Optional, Union
+from pathlib import Path
+from shutil import copyfile
+from typing import Dict, List, Optional, Tuple, Union

import sentencepiece
@@ -15,7 +17,7 @@ vocab_files_names = {
    "vocab": "vocab.json",
    "tokenizer_config_file": "tokenizer_config.json",
}
-MODEL_NAMES = ("opus-mt-en-de",)  # TODO(SS): the only required constant is vocab_files_names
+MODEL_NAMES = ("opus-mt-en-de",)  # TODO(SS): delete this, the only required constant is vocab_files_names
PRETRAINED_VOCAB_FILES_MAP = {
    k: {m: f"{S3_BUCKET_PREFIX}/Helsinki-NLP/{m}/{fname}" for m in MODEL_NAMES}
    for k, fname in vocab_files_names.items()
@@ -55,14 +57,16 @@ class MarianTokenizer(PreTrainedTokenizer):
        eos_token="</s>",
        pad_token="<pad>",
        max_len=512,
+        **kwargs,
    ):
        super().__init__(
-            # bos_token=bos_token,
+            # bos_token=bos_token, unused. Start decoding with config.decoder_start_token_id
            max_len=max_len,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
+            **kwargs,
        )
        self.encoder = load_json(vocab)
        if self.unk_token not in self.encoder:
@@ -72,21 +76,23 @@ class MarianTokenizer(PreTrainedTokenizer):
        self.source_lang = source_lang
        self.target_lang = target_lang
-        self.supported_language_codes: list = [k for k in self.encoder if k.startswith(">>") and k.endswith("<<")]
+        self.spm_files = [source_spm, target_spm]

        # load SentencePiece model for pre-processing
-        self.spm_source = sentencepiece.SentencePieceProcessor()
-        self.spm_source.Load(source_spm)
-        self.spm_target = sentencepiece.SentencePieceProcessor()
-        self.spm_target.Load(target_spm)
+        self.spm_source = load_spm(source_spm)
+        self.spm_target = load_spm(target_spm)
+        self.current_spm = self.spm_source

        # Multilingual target side: default to using first supported language code.
+        self.supported_language_codes: list = [k for k in self.encoder if k.startswith(">>") and k.endswith("<<")]
+
+        self._setup_normalizer()
+
+    def _setup_normalizer(self):
        try:
            from mosestokenizer import MosesPunctuationNormalizer

-            self.punc_normalizer = MosesPunctuationNormalizer(source_lang)
+            self.punc_normalizer = MosesPunctuationNormalizer(self.source_lang)
        except ImportError:
            warnings.warn("Recommended: pip install mosestokenizer")
            self.punc_normalizer = lambda x: x
@@ -176,6 +182,65 @@ class MarianTokenizer(PreTrainedTokenizer):
    def vocab_size(self) -> int:
        return len(self.encoder)

+    def save_vocabulary(self, save_directory: str) -> Tuple[str]:
+        """save vocab file to json and copy spm files from their original path."""
+        save_dir = Path(save_directory)
+        assert save_dir.is_dir(), f"{save_directory} should be a directory"
+        save_json(self.encoder, save_dir / self.vocab_files_names["vocab"])
+
+        for f in self.spm_files:
+            dest_path = save_dir / Path(f).name
+            if not dest_path.exists():
+                copyfile(f, save_dir / Path(f).name)
+
+        return tuple(save_dir / f for f in self.vocab_files_names)
+
+    def get_vocab(self) -> Dict:
+        vocab = self.encoder.copy()
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    def __getstate__(self) -> Dict:
+        state = self.__dict__.copy()
+        state.update({k: None for k in ["spm_source", "spm_target", "current_spm", "punc_normalizer"]})
+        return state
+
+    def __setstate__(self, d: Dict) -> None:
+        self.__dict__ = d
+        self.spm_source, self.spm_target = (load_spm(f) for f in self.spm_files)
+        self.current_spm = self.spm_source
+        self._setup_normalizer()
+
+    def num_special_tokens_to_add(self, **unused):
+        """Just EOS"""
+        return 1
+
+    def _special_token_mask(self, seq):
+        all_special_ids = set(self.all_special_ids)  # call it once instead of inside list comp
+        all_special_ids.remove(self.unk_token_id)  # <unk> is only sometimes special
+        return [1 if x in all_special_ids else 0 for x in seq]
+
+    def get_special_tokens_mask(
+        self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
+    ) -> List[int]:
+        """Get list where entries are [1] if a token is [eos] or [pad] else 0."""
+        if already_has_special_tokens:
+            return self._special_token_mask(token_ids_0)
+        elif token_ids_1 is None:
+            return self._special_token_mask(token_ids_0) + [1]
+        else:
+            return self._special_token_mask(token_ids_0 + token_ids_1) + [1]
+
+
+def load_spm(path: str) -> sentencepiece.SentencePieceProcessor:
+    spm = sentencepiece.SentencePieceProcessor()
+    spm.Load(path)
+    return spm
+
+
+def save_json(data, path: str) -> None:
+    with open(path, "w") as f:
+        json.dump(data, f, indent=2)
+
+
def load_json(path: str) -> Union[Dict, List]:
    with open(path, "r") as f:
...
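The __getstate__/__setstate__ pair above is there to make the tokenizer picklable: the loaded SentencePiece processors and the Moses punctuation normalizer are dropped from the pickled state and rebuilt from self.spm_files on load. A minimal sketch of the round trip this should allow, not part of the commit (it assumes the slow Helsinki-NLP/opus-mt-en-de checkpoint can be downloaded):

import pickle

from transformers.tokenization_marian import MarianTokenizer

tok = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")

# __getstate__ nulls spm_source / spm_target / current_spm / punc_normalizer;
# __setstate__ reloads them via load_spm() and _setup_normalizer().
restored = pickle.loads(pickle.dumps(tok))

assert restored.get_vocab() == tok.get_vocab()
print(restored.prepare_translation_batch(["I am a small frog"], return_tensors=None).input_ids[0])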
tests/test_modeling_marian.py
@@ -129,11 +129,6 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
        max_indices = logits.argmax(-1)
        self.tokenizer.batch_decode(max_indices)

-    def test_tokenizer_equivalence(self):
-        batch = self.tokenizer.prepare_translation_batch(["I am a small frog"]).to(torch_device)
-        expected = [38, 121, 14, 697, 38848, 0]
-        self.assertListEqual(expected, batch.input_ids[0].tolist())
-
    def test_unk_support(self):
        t = self.tokenizer
        ids = t.prepare_translation_batch(["||"]).to(torch_device).input_ids[0].tolist()
...
tests/test_tokenization_marian.py (new file)
# coding=utf-8
# Copyright 2020 Huggingface
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
from pathlib import Path
from shutil import copyfile
from transformers.tokenization_marian import MarianTokenizer, save_json, vocab_files_names
from transformers.tokenization_utils import BatchEncoding
from .test_tokenization_common import TokenizerTesterMixin
from .utils import slow
SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
mock_tokenizer_config = {"target_lang": "fi", "source_lang": "en"}
zh_code = ">>zh<<"
ORG_NAME = "Helsinki-NLP/"
class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    tokenizer_class = MarianTokenizer

    def setUp(self):
        super().setUp()
        vocab = ["</s>", "<unk>", "▁This", "▁is", "▁a", "▁t", "est", "\u0120", "<pad>"]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        save_dir = Path(self.tmpdirname)
        save_json(vocab_tokens, save_dir / vocab_files_names["vocab"])
        save_json(mock_tokenizer_config, save_dir / vocab_files_names["tokenizer_config_file"])
        if not (save_dir / vocab_files_names["source_spm"]).exists():
            copyfile(SAMPLE_SP, save_dir / vocab_files_names["source_spm"])
            copyfile(SAMPLE_SP, save_dir / vocab_files_names["target_spm"])

        tokenizer = MarianTokenizer.from_pretrained(self.tmpdirname)
        tokenizer.save_pretrained(self.tmpdirname)

    def get_tokenizer(self, max_len=None, **kwargs) -> MarianTokenizer:
        # overwrite max_len=512 default
        return MarianTokenizer.from_pretrained(self.tmpdirname, max_len=max_len, **kwargs)

    def get_input_output_texts(self):
        return (
            "This is a test",
            "This is a test",
        )

    @slow
    def test_tokenizer_equivalence_en_de(self):
        en_de_tokenizer = MarianTokenizer.from_pretrained(f"{ORG_NAME}opus-mt-en-de")
        batch = en_de_tokenizer.prepare_translation_batch(["I am a small frog"], return_tensors=None)
        self.assertIsInstance(batch, BatchEncoding)
        expected = [38, 121, 14, 697, 38848, 0]
        self.assertListEqual(expected, batch.input_ids[0])
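As a closing illustration of the new special-token helpers, a hedged sketch that is not part of the commit (it reuses the expected ids from the slow test above and assumes the en-de checkpoint can be downloaded):

from transformers.tokenization_marian import MarianTokenizer

tok = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
ids = tok.prepare_translation_batch(["I am a small frog"], return_tensors=None).input_ids[0]
# Per the test above, ids == [38, 121, 14, 697, 38848, 0]; the trailing 0 is the
# appended </s>, which is what num_special_tokens_to_add() == 1 accounts for.
mask = tok.get_special_tokens_mask(ids, already_has_special_tokens=True)
print(mask)  # expected: [0, 0, 0, 0, 0, 1] -- only the eos position counts as special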