"tests/vscode:/vscode.git/clone" did not exist on "4d9e45f3ef624cab41f605d7439862ce23ca806a"
Unverified Commit 42ad693b authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Regression pipeline device (#22190)

* Fix regression in pipeline when device=-1 is passed

* Add regression test
parent 73768147
......@@ -769,8 +769,8 @@ class Pipeline(_ScikitCompat):
self.modelcard = modelcard
self.framework = framework
if self.framework == "pt" and device is not None:
self.model = self.model.to(device=device)
if self.framework == "pt" and device is not None and not (isinstance(device, int) and device < 0):
self.model.to(device)
if device is None:
# `accelerate` device map
......
......@@ -484,6 +484,14 @@ class PipelineUtilsTest(unittest.TestCase):
outputs = list(dataset)
self.assertEqual(outputs, [[{"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}]])
def test_pipeline_negative_device(self):
# To avoid regressing, pipeline used to accept device=-1
classifier = pipeline("text-generation", "hf-internal-testing/tiny-random-bert", device=-1)
expected_output = [{"generated_text": ANY(str)}]
actual_output = classifier("Test input.")
self.assertEqual(expected_output, actual_output)
@slow
@require_torch
def test_load_default_pipelines_pt(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment