Unverified Commit 5f968557 authored by Huazhong Ji's avatar Huazhong Ji Committed by GitHub
Browse files

Add npu device for pipeline (#28885)



add npu device for pipeline
Co-authored-by: default avatarunit_test <test@unit.com>
parent 308d2b90
......@@ -41,6 +41,7 @@ from ..utils import (
is_tf_available,
is_torch_available,
is_torch_cuda_available,
is_torch_npu_available,
is_torch_xpu_available,
logging,
)
......@@ -852,6 +853,8 @@ class Pipeline(_ScikitCompat):
self.device = torch.device("cpu")
elif is_torch_cuda_available():
self.device = torch.device(f"cuda:{device}")
elif is_torch_npu_available():
self.device = torch.device(f"npu:{device}")
elif is_torch_xpu_available(check_device=True):
self.device = torch.device(f"xpu:{device}")
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment