"...composable_kernel.git" did not exist on "69a6dc749ea3b1074991af9715b103c28590f037"
Unverified Commit 4ab74245 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[cleanup/marian] pipelines test and new kwarg (#4812)

parent 875288b3
...@@ -48,13 +48,12 @@ class MarianTokenizer(PreTrainedTokenizer): ...@@ -48,13 +48,12 @@ class MarianTokenizer(PreTrainedTokenizer):
unk_token="<unk>", unk_token="<unk>",
eos_token="</s>", eos_token="</s>",
pad_token="<pad>", pad_token="<pad>",
max_len=512, model_max_length=512,
**kwargs, **kwargs
): ):
super().__init__( super().__init__(
# bos_token=bos_token, unused. Start decoding with config.decoder_start_token_id # bos_token=bos_token, unused. Start decoding with config.decoder_start_token_id
max_len=max_len, model_max_length=model_max_length,
eos_token=eos_token, eos_token=eos_token,
unk_token=unk_token, unk_token=unk_token,
pad_token=pad_token, pad_token=pad_token,
......
...@@ -38,6 +38,7 @@ if is_torch_available(): ...@@ -38,6 +38,7 @@ if is_torch_available():
convert_opus_name_to_hf_name, convert_opus_name_to_hf_name,
ORG_NAME, ORG_NAME,
) )
from transformers.pipelines import TranslationPipeline
class ModelManagementTests(unittest.TestCase): class ModelManagementTests(unittest.TestCase):
...@@ -189,6 +190,7 @@ class TestMarian_RU_FR(MarianIntegrationTest): ...@@ -189,6 +190,7 @@ class TestMarian_RU_FR(MarianIntegrationTest):
src_text = ["Он показал мне рукопись своей новой пьесы."] src_text = ["Он показал мне рукопись своей новой пьесы."]
expected_text = ["Il m'a montré le manuscrit de sa nouvelle pièce."] expected_text = ["Il m'a montré le manuscrit de sa nouvelle pièce."]
@slow
def test_batch_generation_ru_fr(self): def test_batch_generation_ru_fr(self):
self._assert_generated_batch_equal_expected() self._assert_generated_batch_equal_expected()
...@@ -199,6 +201,7 @@ class TestMarian_MT_EN(MarianIntegrationTest): ...@@ -199,6 +201,7 @@ class TestMarian_MT_EN(MarianIntegrationTest):
src_text = ["Billi messu b'mod ġentili, Ġesù fejjaq raġel li kien milqut bil - marda kerha tal - ġdiem."] src_text = ["Billi messu b'mod ġentili, Ġesù fejjaq raġel li kien milqut bil - marda kerha tal - ġdiem."]
expected_text = ["Touching gently, Jesus healed a man who was affected by the sad disease of leprosy."] expected_text = ["Touching gently, Jesus healed a man who was affected by the sad disease of leprosy."]
@slow
def test_batch_generation_mt_en(self): def test_batch_generation_mt_en(self):
self._assert_generated_batch_equal_expected() self._assert_generated_batch_equal_expected()
...@@ -229,6 +232,11 @@ class TestMarian_en_ROMANCE(MarianIntegrationTest): ...@@ -229,6 +232,11 @@ class TestMarian_en_ROMANCE(MarianIntegrationTest):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
self.tokenizer.prepare_translation_batch([""]) self.tokenizer.prepare_translation_batch([""])
def test_pipeline(self):
pipeline = TranslationPipeline(self.model, self.tokenizer, framework="pt")
output = pipeline(self.src_text)
self.assertEqual(self.expected_text, [x["translation_text"] for x in output])
@require_torch @require_torch
class TestConversionUtils(unittest.TestCase): class TestConversionUtils(unittest.TestCase):
......
...@@ -52,8 +52,7 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase): ...@@ -52,8 +52,7 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer.save_pretrained(self.tmpdirname) tokenizer.save_pretrained(self.tmpdirname)
def get_tokenizer(self, max_len=None, **kwargs) -> MarianTokenizer: def get_tokenizer(self, max_len=None, **kwargs) -> MarianTokenizer:
# overwrite max_len=512 default return MarianTokenizer.from_pretrained(self.tmpdirname, model_max_length=max_len, **kwargs)
return MarianTokenizer.from_pretrained(self.tmpdirname, max_len=max_len, **kwargs)
def get_input_output_texts(self): def get_input_output_texts(self):
return ( return (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment