"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "31fc9a04b6933b6298f5f8343eea7d1a64f9f9bf"
Unverified Commit 1fc12960 authored by Huazhong Ji's avatar Huazhong Ji Committed by GitHub
Browse files

get default device through `PartialState().default_device` as it has been...

get default device through `PartialState().default_device` as it has been officially released (#27256)

get default device through `PartialState().default_device` as it has
been officially released
parent e547458c
......@@ -46,6 +46,7 @@ if is_torch_available():
import torch
if is_accelerate_available():
from accelerate import PartialState
from accelerate.utils import send_to_device
......@@ -529,7 +530,7 @@ class PipelineTool(Tool):
if self.device_map is not None:
self.device = list(self.model.hf_device_map.values())[0]
else:
self.device = get_default_device()
self.device = PartialState().default_device
if self.device_map is None:
self.model.to(self.device)
......@@ -597,23 +598,6 @@ def launch_gradio_demo(tool_class: Tool):
).launch()
# TODO: Migrate to Accelerate for this once `PartialState.default_device` makes its way into a release.
def get_default_device():
logger.warning(
"`get_default_device` is deprecated and will be replaced with `accelerate`'s `PartialState().default_device` "
"in version 4.38 of 🤗 Transformers. "
)
if not is_torch_available():
raise ImportError("Please install torch in order to use this tool.")
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
return torch.device("mps")
elif torch.cuda.is_available():
return torch.device("cuda")
else:
return torch.device("cpu")
TASK_MAPPING = {
"document-question-answering": "DocumentQuestionAnsweringTool",
"image-captioning": "ImageCaptioningTool",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment