"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "6fe8a693ebbfa6e70b880f7c24e0cf524be6fb25"
Unverified Commit 3d495c61 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

Fix marian tokenizer save pretrained (#5043)

parent d5477baf
...@@ -40,9 +40,9 @@ class MarianTokenizer(PreTrainedTokenizer): ...@@ -40,9 +40,9 @@ class MarianTokenizer(PreTrainedTokenizer):
def __init__( def __init__(
self, self,
vocab=None, vocab,
source_spm=None, source_spm,
target_spm=None, target_spm,
source_lang=None, source_lang=None,
target_lang=None, target_lang=None,
unk_token="<unk>", unk_token="<unk>",
...@@ -59,6 +59,7 @@ class MarianTokenizer(PreTrainedTokenizer): ...@@ -59,6 +59,7 @@ class MarianTokenizer(PreTrainedTokenizer):
pad_token=pad_token, pad_token=pad_token,
**kwargs, **kwargs,
) )
assert Path(source_spm).exists(), f"cannot find spm source {source_spm}"
self.encoder = load_json(vocab) self.encoder = load_json(vocab)
if self.unk_token not in self.encoder: if self.unk_token not in self.encoder:
raise KeyError("<unk> token must be in vocab") raise KeyError("<unk> token must be in vocab")
...@@ -179,10 +180,11 @@ class MarianTokenizer(PreTrainedTokenizer): ...@@ -179,10 +180,11 @@ class MarianTokenizer(PreTrainedTokenizer):
assert save_dir.is_dir(), f"{save_directory} should be a directory" assert save_dir.is_dir(), f"{save_directory} should be a directory"
save_json(self.encoder, save_dir / self.vocab_files_names["vocab"]) save_json(self.encoder, save_dir / self.vocab_files_names["vocab"])
for f in self.spm_files: for orig, f in zip(["source.spm", "target.spm"], self.spm_files):
dest_path = save_dir / Path(f).name dest_path = save_dir / Path(f).name
if not dest_path.exists(): if not dest_path.exists():
copyfile(f, save_dir / Path(f).name) copyfile(f, save_dir / orig)
return tuple(save_dir / f for f in self.vocab_files_names) return tuple(save_dir / f for f in self.vocab_files_names)
def get_vocab(self) -> Dict: def get_vocab(self) -> Dict:
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import os import os
import tempfile
import unittest import unittest
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
...@@ -23,7 +24,6 @@ from transformers.tokenization_marian import MarianTokenizer, save_json, vocab_f ...@@ -23,7 +24,6 @@ from transformers.tokenization_marian import MarianTokenizer, save_json, vocab_f
from transformers.tokenization_utils import BatchEncoding from transformers.tokenization_utils import BatchEncoding
from .test_tokenization_common import TokenizerTesterMixin from .test_tokenization_common import TokenizerTesterMixin
from .utils import slow
SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model") SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
...@@ -60,10 +60,15 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase): ...@@ -60,10 +60,15 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
"This is a test", "This is a test",
) )
@slow
def test_tokenizer_equivalence_en_de(self): def test_tokenizer_equivalence_en_de(self):
en_de_tokenizer = MarianTokenizer.from_pretrained(f"{ORG_NAME}opus-mt-en-de") en_de_tokenizer = MarianTokenizer.from_pretrained(f"{ORG_NAME}opus-mt-en-de")
batch = en_de_tokenizer.prepare_translation_batch(["I am a small frog"], return_tensors=None) batch = en_de_tokenizer.prepare_translation_batch(["I am a small frog"], return_tensors=None)
self.assertIsInstance(batch, BatchEncoding) self.assertIsInstance(batch, BatchEncoding)
expected = [38, 121, 14, 697, 38848, 0] expected = [38, 121, 14, 697, 38848, 0]
self.assertListEqual(expected, batch.input_ids[0]) self.assertListEqual(expected, batch.input_ids[0])
save_dir = tempfile.mkdtemp()
en_de_tokenizer.save_pretrained(save_dir)
contents = [x.name for x in Path(save_dir).glob("*")]
self.assertIn("source.spm", contents)
MarianTokenizer.from_pretrained(save_dir)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment