Unverified commit aecf0c53, authored by a120092009 and committed by GitHub

Add MLU Support. (#12629)



* Add MLU Support.

* fix comment.

* rename is_mlu_available to is_torch_mlu_available

* Apply style fixes

---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 0c758929
@@ -108,6 +108,7 @@ from .import_utils import (
     is_tensorboard_available,
     is_timm_available,
     is_torch_available,
+    is_torch_mlu_available,
     is_torch_npu_available,
     is_torch_version,
     is_torch_xla_available,
...
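For reference, the re-export above makes the new helper importable from the package-level utils module, alongside the existing helpers such as is_torch_npu_available. A minimal sanity check, assuming a diffusers install with this patch applied:

```python
# Quick check that the new re-export resolves; prints True only when the
# Cambricon torch_mlu package is importable in the current environment.
from diffusers.utils import is_torch_mlu_available

print(is_torch_mlu_available())
```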
@@ -192,6 +192,7 @@ except importlib_metadata.PackageNotFoundError:
 _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
 _torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
+_torch_mlu_available, _torch_mlu_version = _is_package_available("torch_mlu")
 _transformers_available, _transformers_version = _is_package_available("transformers")
 _hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
 _kernels_available, _kernels_version = _is_package_available("kernels")

@@ -243,6 +244,10 @@ def is_torch_npu_available():
     return _torch_npu_available


+def is_torch_mlu_available():
+    return _torch_mlu_available
+
+
 def is_flax_available():
     return _flax_available
...
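The new check follows the same pattern as the neighbouring helpers: the torch_mlu package is probed once at import time via _is_package_available, and the accessor simply returns the cached flag. A simplified, self-contained sketch of that pattern (the _probe helper below is an illustrative stand-in, not diffusers' exact _is_package_available implementation):

```python
# Illustrative stand-in for the availability-probe pattern used in import_utils.
# _probe is a simplified, hypothetical helper; diffusers' real _is_package_available
# differs in details but follows the same idea.
import importlib.metadata
import importlib.util


def _probe(pkg_name: str) -> tuple[bool, str]:
    """Return (available, version); version is "N/A" when the package is missing."""
    if importlib.util.find_spec(pkg_name) is None:
        return False, "N/A"
    try:
        return True, importlib.metadata.version(pkg_name)
    except importlib.metadata.PackageNotFoundError:
        return False, "N/A"


# Probe once at import time, then expose the cached result through an accessor.
_torch_mlu_available, _torch_mlu_version = _probe("torch_mlu")


def is_torch_mlu_available():
    return _torch_mlu_available
```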
@@ -20,7 +20,7 @@ import os
 from typing import Callable, Dict, List, Optional, Tuple, Union

 from . import logging
-from .import_utils import is_torch_available, is_torch_npu_available, is_torch_version
+from .import_utils import is_torch_available, is_torch_mlu_available, is_torch_npu_available, is_torch_version


 if is_torch_available():

@@ -286,6 +286,8 @@ def get_device():
         return "xpu"
     elif torch.backends.mps.is_available():
         return "mps"
+    elif is_torch_mlu_available():
+        return "mlu"
     else:
         return "cpu"
...
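With this change, get_device() probes MLU right before falling back to "cpu", so callers that already rely on it pick up MLU devices automatically. A hedged usage sketch (the module path diffusers.utils.torch_utils is an assumption based on the imports visible in this hunk):

```python
# Sketch: place a module and its inputs on whatever accelerator get_device()
# reports (e.g. "xpu", "mps", the new "mlu", or "cpu"). Assumes the patched diffusers.
import torch

from diffusers.utils.torch_utils import get_device

device = get_device()
model = torch.nn.Linear(16, 16).to(device)
x = torch.randn(2, 16, device=device)
print(device, model(x).shape)
```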