Unverified Commit c58e6c12 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[marian tests ] pass device to pipeline (#4815)

parent ddf9a3df
......@@ -233,7 +233,8 @@ class TestMarian_en_ROMANCE(MarianIntegrationTest):
self.tokenizer.prepare_translation_batch([""])
def test_pipeline(self):
pipeline = TranslationPipeline(self.model, self.tokenizer, framework="pt")
device = 0 if torch_device == "cuda" else -1
pipeline = TranslationPipeline(self.model, self.tokenizer, framework="pt", device=device)
output = pipeline(self.src_text)
self.assertEqual(self.expected_text, [x["translation_text"] for x in output])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment