Unverified Commit e0c50b27 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`pipeline`] revisit device check for pipeline (#25207)

* revisit device check for pipeline

* let's raise an error.
parent 52206066
......@@ -775,12 +775,20 @@ class Pipeline(_ScikitCompat):
self.modelcard = modelcard
self.framework = framework
# `accelerate` device map
hf_device_map = getattr(self.model, "hf_device_map", None)
if hf_device_map is not None and device is not None:
raise ValueError(
"The model has been loaded with `accelerate` and therefore cannot be moved to a specific device. Please "
"discard the `device` argument when creating your pipeline object."
)
# We shouldn't call `model.to()` for models loaded with accelerate
if self.framework == "pt" and device is not None and not (isinstance(device, int) and device < 0):
self.model.to(device)
if device is None:
# `accelerate` device map
hf_device_map = getattr(self.model, "hf_device_map", None)
if hf_device_map is not None:
# Take the first device used by `accelerate`.
device = next(iter(hf_device_map.values()))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment