"docs/source/zh/main_classes/quantization.md" did not exist on "74a3cebfa51b539bfcfa79b33686cc090b7074e8"
Unverified Commit 9d4a4550 authored by Julien Chaumond's avatar Julien Chaumond Committed by GitHub
Browse files

`pipeline` support for `device="mps"` (or any other string) (#18494)



* `pipeline` support for `device="mps"` (or any other string)

* Simplify `if` nesting

* Update src/transformers/pipelines/base.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Fix? @sgugger

* passing `attr=None` is not the same as not passing `attr` 🤯
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 0d0aada5
......@@ -422,6 +422,7 @@ def pipeline(
revision: Optional[str] = None,
use_fast: bool = True,
use_auth_token: Optional[Union[str, bool]] = None,
device: Optional[Union[int, str, "torch.device"]] = None,
device_map=None,
torch_dtype=None,
trust_remote_code: Optional[bool] = None,
......@@ -508,6 +509,9 @@ def pipeline(
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `huggingface-cli login` (stored in `~/.huggingface`).
device (`int` or `str` or `torch.device`):
Defines the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank like `1`) on which this
pipeline will be allocated.
device_map (`str` or `Dict[str, Union[int, str, torch.device]`, *optional*):
Sent directly as `model_kwargs` (just a simpler shortcut). When `accelerate` library is present, set
`device_map="auto"` to compute the most optimized `device_map` automatically. [More
......@@ -811,4 +815,7 @@ def pipeline(
if feature_extractor is not None:
kwargs["feature_extractor"] = feature_extractor
if device is not None:
kwargs["device"] = device
return pipeline_class(model=model, framework=framework, task=task, **kwargs)
......@@ -704,7 +704,7 @@ PIPELINE_INIT_ARGS = r"""
Reference to the object in charge of parsing supplied pipeline parameters.
device (`int`, *optional*, defaults to -1):
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run the model on
the associated CUDA device id. You can pass native `torch.device` too.
the associated CUDA device id. You can pass native `torch.device` or a `str` too.
binary_output (`bool`, *optional*, defaults to `False`):
Flag indicating if the output the pipeline should happen in a binary format (i.e., pickle) or as raw text.
"""
......@@ -747,7 +747,7 @@ class Pipeline(_ScikitCompat):
framework: Optional[str] = None,
task: str = "",
args_parser: ArgumentHandler = None,
device: int = -1,
device: Union[int, str, "torch.device"] = -1,
binary_output: bool = False,
**kwargs,
):
......@@ -760,14 +760,21 @@ class Pipeline(_ScikitCompat):
self.feature_extractor = feature_extractor
self.modelcard = modelcard
self.framework = framework
if is_torch_available() and isinstance(device, torch.device):
self.device = device
if is_torch_available() and self.framework == "pt":
if isinstance(device, torch.device):
self.device = device
elif isinstance(device, str):
self.device = torch.device(device)
elif device < 0:
self.device = torch.device("cpu")
else:
self.device = torch.device("cuda:{device}")
else:
self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else f"cuda:{device}")
self.device = device
self.binary_output = binary_output
# Special handling
if self.framework == "pt" and self.device.type == "cuda":
if self.framework == "pt" and self.device.type != "cpu":
self.model = self.model.to(self.device)
# Update config with task specific parameters
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment