"tests/vscode:/vscode.git/clone" did not exist on "1e651ca2c9f12bdcc5d63da8830847706e186f22"
Unverified Commit c58e6c12 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[marian tests ] pass device to pipeline (#4815)

parent ddf9a3df
...@@ -233,7 +233,8 @@ class TestMarian_en_ROMANCE(MarianIntegrationTest): ...@@ -233,7 +233,8 @@ class TestMarian_en_ROMANCE(MarianIntegrationTest):
self.tokenizer.prepare_translation_batch([""]) self.tokenizer.prepare_translation_batch([""])
def test_pipeline(self): def test_pipeline(self):
pipeline = TranslationPipeline(self.model, self.tokenizer, framework="pt") device = 0 if torch_device == "cuda" else -1
pipeline = TranslationPipeline(self.model, self.tokenizer, framework="pt", device=device)
output = pipeline(self.src_text) output = pipeline(self.src_text)
self.assertEqual(self.expected_text, [x["translation_text"] for x in output]) self.assertEqual(self.expected_text, [x["translation_text"] for x in output])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment