Unverified Commit a573ac74 authored by yuanwu2017's avatar yuanwu2017 Committed by GitHub
Browse files

Add the XPU device check for pipeline mode (#28326)



* Add the XPU check for pipeline mode

When setting the XPU device for a pipeline, it needs to use is_torch_xpu_available to load ipex and determine whether the device is available.
Signed-off-by: default avataryuanwu <yuan.wu@intel.com>

* Don't move model to device when hf_device_map isn't None

1. Don't move model to device when hf_device_map is not None
2. The device string may include the device index, so use 'in' instead of an equality check
Signed-off-by: default avataryuanwu <yuan.wu@intel.com>

* Raise the error when xpu is not available
Signed-off-by: default avataryuanwu <yuan.wu@intel.com>

* Update src/transformers/pipelines/base.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/pipelines/base.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Modify the error message
Signed-off-by: default avataryuanwu <yuan.wu@intel.com>

* Change message format.
Signed-off-by: default avataryuanwu <yuan.wu@intel.com>

---------
Signed-off-by: default avataryuanwu <yuan.wu@intel.com>
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
parent 1b9a2e4c
...@@ -34,7 +34,16 @@ from ..image_processing_utils import BaseImageProcessor ...@@ -34,7 +34,16 @@ from ..image_processing_utils import BaseImageProcessor
from ..modelcard import ModelCard from ..modelcard import ModelCard
from ..models.auto.configuration_auto import AutoConfig from ..models.auto.configuration_auto import AutoConfig
from ..tokenization_utils import PreTrainedTokenizer from ..tokenization_utils import PreTrainedTokenizer
from ..utils import ModelOutput, add_end_docstrings, infer_framework, is_tf_available, is_torch_available, logging from ..utils import (
ModelOutput,
add_end_docstrings,
infer_framework,
is_tf_available,
is_torch_available,
is_torch_cuda_available,
is_torch_xpu_available,
logging,
)
GenericTensor = Union[List["GenericTensor"], "torch.Tensor", "tf.Tensor"] GenericTensor = Union[List["GenericTensor"], "torch.Tensor", "tf.Tensor"]
...@@ -792,10 +801,6 @@ class Pipeline(_ScikitCompat): ...@@ -792,10 +801,6 @@ class Pipeline(_ScikitCompat):
"discard the `device` argument when creating your pipeline object." "discard the `device` argument when creating your pipeline object."
) )
# We shouldn't call `model.to()` for models loaded with accelerate
if self.framework == "pt" and device is not None and not (isinstance(device, int) and device < 0):
self.model.to(device)
if device is None: if device is None:
if hf_device_map is not None: if hf_device_map is not None:
# Take the first device used by `accelerate`. # Take the first device used by `accelerate`.
...@@ -805,18 +810,35 @@ class Pipeline(_ScikitCompat): ...@@ -805,18 +810,35 @@ class Pipeline(_ScikitCompat):
if is_torch_available() and self.framework == "pt": if is_torch_available() and self.framework == "pt":
if isinstance(device, torch.device): if isinstance(device, torch.device):
if device.type == "xpu" and not is_torch_xpu_available(check_device=True):
raise ValueError(f'{device} is not available, you should use device="cpu" instead')
self.device = device self.device = device
elif isinstance(device, str): elif isinstance(device, str):
if "xpu" in device and not is_torch_xpu_available(check_device=True):
raise ValueError(f'{device} is not available, you should use device="cpu" instead')
self.device = torch.device(device) self.device = torch.device(device)
elif device < 0: elif device < 0:
self.device = torch.device("cpu") self.device = torch.device("cpu")
else: elif is_torch_cuda_available():
self.device = torch.device(f"cuda:{device}") self.device = torch.device(f"cuda:{device}")
elif is_torch_xpu_available(check_device=True):
self.device = torch.device(f"xpu:{device}")
else:
raise ValueError(f"{device} unrecognized or not available.")
else: else:
self.device = device if device is not None else -1 self.device = device if device is not None else -1
self.torch_dtype = torch_dtype self.torch_dtype = torch_dtype
self.binary_output = binary_output self.binary_output = binary_output
# We shouldn't call `model.to()` for models loaded with accelerate
if (
self.framework == "pt"
and self.device is not None
and not (isinstance(self.device, int) and self.device < 0)
and hf_device_map is None
):
self.model.to(self.device)
# Update config and generation_config with task specific parameters # Update config and generation_config with task specific parameters
task_specific_params = self.model.config.task_specific_params task_specific_params = self.model.config.task_specific_params
if task_specific_params is not None and task in task_specific_params: if task_specific_params is not None and task in task_specific_params:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment