Unverified Commit 42ad693b authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Regression pipeline device (#22190)

* Fix regression in pipeline when device=-1 is passed

* Add regression test
parent 73768147
...@@ -769,8 +769,8 @@ class Pipeline(_ScikitCompat): ...@@ -769,8 +769,8 @@ class Pipeline(_ScikitCompat):
self.modelcard = modelcard self.modelcard = modelcard
self.framework = framework self.framework = framework
if self.framework == "pt" and device is not None: if self.framework == "pt" and device is not None and not (isinstance(device, int) and device < 0):
self.model = self.model.to(device=device) self.model.to(device)
if device is None: if device is None:
# `accelerate` device map # `accelerate` device map
......
...@@ -484,6 +484,14 @@ class PipelineUtilsTest(unittest.TestCase): ...@@ -484,6 +484,14 @@ class PipelineUtilsTest(unittest.TestCase):
outputs = list(dataset) outputs = list(dataset)
self.assertEqual(outputs, [[{"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}]]) self.assertEqual(outputs, [[{"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}]])
def test_pipeline_negative_device(self):
# To avoid regressing, pipeline used to accept device=-1
classifier = pipeline("text-generation", "hf-internal-testing/tiny-random-bert", device=-1)
expected_output = [{"generated_text": ANY(str)}]
actual_output = classifier("Test input.")
self.assertEqual(expected_output, actual_output)
@slow @slow
@require_torch @require_torch
def test_load_default_pipelines_pt(self): def test_load_default_pipelines_pt(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment