Unverified Commit 4144c354 authored by Fanli Lin's avatar Fanli Lin Committed by GitHub
Browse files

auto-detect device when no device is passed to pipeline (#31398)



* fix device

* Update src/transformers/pipelines/base.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* bug fix

* add warning

---------
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent cd5f7c17
...@@ -843,6 +843,17 @@ class Pipeline(_ScikitCompat, PushToHubMixin): ...@@ -843,6 +843,17 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
device = next(iter(hf_device_map.values())) device = next(iter(hf_device_map.values()))
else: else:
device = -1 device = -1
if (
is_torch_mlu_available()
or is_torch_cuda_available()
or is_torch_npu_available()
or is_torch_xpu_available(check_device=True)
or is_torch_mps_available()
):
logging.warning(
"Hardware accelerator e.g. GPU is available in the environment, but no `device` argument"
" is passed to the `Pipeline` object. Model will be on CPU."
)
if is_torch_available() and self.framework == "pt": if is_torch_available() and self.framework == "pt":
if device == -1 and self.model.device is not None: if device == -1 and self.model.device is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment