Unverified Commit 43cb03a9 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

MarianTokenizer.prepare_translation_batch uses new tokenizer API (#5182)

parent 13deb95a
...@@ -129,6 +129,8 @@ class MarianTokenizer(PreTrainedTokenizer): ...@@ -129,6 +129,8 @@ class MarianTokenizer(PreTrainedTokenizer):
max_length: Optional[int] = None, max_length: Optional[int] = None,
pad_to_max_length: bool = True, pad_to_max_length: bool = True,
return_tensors: str = "pt", return_tensors: str = "pt",
truncation_strategy="only_first",
padding="longest",
) -> BatchEncoding: ) -> BatchEncoding:
"""Prepare model inputs for translation. For best performance, translate one sentence at a time. """Prepare model inputs for translation. For best performance, translate one sentence at a time.
Arguments: Arguments:
...@@ -147,24 +149,21 @@ class MarianTokenizer(PreTrainedTokenizer): ...@@ -147,24 +149,21 @@ class MarianTokenizer(PreTrainedTokenizer):
raise ValueError(f"found empty string in src_texts: {src_texts}") raise ValueError(f"found empty string in src_texts: {src_texts}")
self.current_spm = self.spm_source self.current_spm = self.spm_source
src_texts = [self.normalize(t) for t in src_texts] # this does not appear to do much src_texts = [self.normalize(t) for t in src_texts] # this does not appear to do much
model_inputs: BatchEncoding = self.batch_encode_plus( tokenizer_kwargs = dict(
src_texts,
add_special_tokens=True, add_special_tokens=True,
return_tensors=return_tensors, return_tensors=return_tensors,
max_length=max_length, max_length=max_length,
pad_to_max_length=pad_to_max_length, pad_to_max_length=pad_to_max_length,
truncation_strategy=truncation_strategy,
padding=padding,
) )
model_inputs: BatchEncoding = self(src_texts, **tokenizer_kwargs)
if tgt_texts is None: if tgt_texts is None:
return model_inputs return model_inputs
self.current_spm = self.spm_target self.current_spm = self.spm_target
decoder_inputs: BatchEncoding = self.batch_encode_plus( decoder_inputs: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)
tgt_texts,
add_special_tokens=True,
return_tensors=return_tensors,
max_length=max_length,
pad_to_max_length=pad_to_max_length,
)
for k, v in decoder_inputs.items(): for k, v in decoder_inputs.items():
model_inputs[f"decoder_{k}"] = v model_inputs[f"decoder_{k}"] = v
self.current_spm = self.spm_source self.current_spm = self.spm_source
......
...@@ -24,6 +24,7 @@ from transformers.tokenization_marian import MarianTokenizer, save_json, vocab_f ...@@ -24,6 +24,7 @@ from transformers.tokenization_marian import MarianTokenizer, save_json, vocab_f
from transformers.tokenization_utils import BatchEncoding from transformers.tokenization_utils import BatchEncoding
from .test_tokenization_common import TokenizerTesterMixin from .test_tokenization_common import TokenizerTesterMixin
from .utils import _torch_available
SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model") SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
...@@ -31,6 +32,7 @@ SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/t ...@@ -31,6 +32,7 @@ SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/t
mock_tokenizer_config = {"target_lang": "fi", "source_lang": "en"} mock_tokenizer_config = {"target_lang": "fi", "source_lang": "en"}
zh_code = ">>zh<<" zh_code = ">>zh<<"
ORG_NAME = "Helsinki-NLP/" ORG_NAME = "Helsinki-NLP/"
FRAMEWORK = "pt" if _torch_available else "tf"
class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase): class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
...@@ -72,3 +74,20 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase): ...@@ -72,3 +74,20 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
contents = [x.name for x in Path(save_dir).glob("*")] contents = [x.name for x in Path(save_dir).glob("*")]
self.assertIn("source.spm", contents) self.assertIn("source.spm", contents)
MarianTokenizer.from_pretrained(save_dir) MarianTokenizer.from_pretrained(save_dir)
def test_outputs_not_longer_than_maxlen(self):
    """An overlong source sentence must be truncated to the model's 512-token limit."""
    tokenizer = self.get_tokenizer()
    # One pathologically long sentence plus one short one: truncation should
    # cap the batch width at 512 regardless of the raw input length.
    texts = ["I am a small frog" * 1000, "I am a small frog"]
    encoded = tokenizer.prepare_translation_batch(texts, return_tensors=FRAMEWORK)
    self.assertIsInstance(encoded, BatchEncoding)
    self.assertEqual(encoded.input_ids.shape, (2, 512))
def test_outputs_can_be_shorter(self):
    """Short inputs should be padded only to the longest sequence, not to max length."""
    tokenizer = self.get_tokenizer()
    texts = ["I am a tiny frog", "I am a small frog"]
    encoded = tokenizer.prepare_translation_batch(texts, return_tensors=FRAMEWORK)
    self.assertIsInstance(encoded, BatchEncoding)
    # "longest" padding: batch width matches the longer sentence (10 tokens), not 512.
    self.assertEqual(encoded.input_ids.shape, (2, 10))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment