"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "22a0dd2ef754676ede6009cd8e86b10b053ac2cc"
Unverified Commit c58e6c12 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[marian tests ] pass device to pipeline (#4815)

parent ddf9a3df
...@@ -233,7 +233,8 @@ class TestMarian_en_ROMANCE(MarianIntegrationTest): ...@@ -233,7 +233,8 @@ class TestMarian_en_ROMANCE(MarianIntegrationTest):
self.tokenizer.prepare_translation_batch([""]) self.tokenizer.prepare_translation_batch([""])
def test_pipeline(self): def test_pipeline(self):
pipeline = TranslationPipeline(self.model, self.tokenizer, framework="pt") device = 0 if torch_device == "cuda" else -1
pipeline = TranslationPipeline(self.model, self.tokenizer, framework="pt", device=device)
output = pipeline(self.src_text) output = pipeline(self.src_text)
self.assertEqual(self.expected_text, [x["translation_text"] for x in output]) self.assertEqual(self.expected_text, [x["translation_text"] for x in output])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment