Unverified Commit 5939ace9 authored by Leo Jiang's avatar Leo Jiang Committed by GitHub
Browse files

Adding NPU for get device function (#11617)



* Adding device choice for npu

* Adding device choice for npu

---------
Co-authored-by: default avatarJ石页 <jiangshuo9@h-partners.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent cc59505e
...@@ -18,7 +18,7 @@ PyTorch utilities: Utilities related to PyTorch ...@@ -18,7 +18,7 @@ PyTorch utilities: Utilities related to PyTorch
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
from . import logging from . import logging
from .import_utils import is_torch_available, is_torch_version from .import_utils import is_torch_available, is_torch_npu_available, is_torch_version
if is_torch_available(): if is_torch_available():
...@@ -166,6 +166,8 @@ def get_torch_cuda_device_capability(): ...@@ -166,6 +166,8 @@ def get_torch_cuda_device_capability():
def get_device(): def get_device():
if torch.cuda.is_available(): if torch.cuda.is_available():
return "cuda" return "cuda"
elif is_torch_npu_available():
return "npu"
elif hasattr(torch, "xpu") and torch.xpu.is_available(): elif hasattr(torch, "xpu") and torch.xpu.is_available():
return "xpu" return "xpu"
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment