Unverified Commit b361a81a authored by sherie's avatar sherie Committed by GitHub
Browse files

[Fix] Fix setup.py for torch_npu v2.1 (#2928)

parent 72084061
...@@ -27,9 +27,16 @@ ...@@ -27,9 +27,16 @@
#define NPU_NAME_SPACE at_npu::native #define NPU_NAME_SPACE at_npu::native
#if MMCV_WITH_XLA
#define REGISTER_NPU_IMPL(key, value) REGISTER_DEVICE_IMPL(key, XLA, value) #define REGISTER_NPU_IMPL(key, value) REGISTER_DEVICE_IMPL(key, XLA, value)
#else
#define REGISTER_NPU_IMPL(key, value) \
REGISTER_DEVICE_IMPL(key, PrivateUse1, value)
#endif
#define CHECK_NPU(x) \ #define CHECK_NPU(x) \
TORCH_CHECK(x.device().type() == at::kXLA, #x " must be a NPU tensor") TORCH_CHECK( \
x.device().type() == at::kXLA || x.device().type() == at::kPrivateUse1, \
#x " must be a NPU tensor")
#endif // PYTORCH_NPU_HELPER_HPP_ #endif // PYTORCH_NPU_HELPER_HPP_
...@@ -396,6 +396,8 @@ def get_extensions(): ...@@ -396,6 +396,8 @@ def get_extensions():
from torch_npu.utils.cpp_extension import NpuExtension from torch_npu.utils.cpp_extension import NpuExtension
define_macros += [('MMCV_WITH_NPU', None)] define_macros += [('MMCV_WITH_NPU', None)]
extension = NpuExtension extension = NpuExtension
if parse_version(torch.__version__) >= parse_version('2.0.0'):
define_macros += [('MMCV_WITH_XLA', None)]
except Exception: except Exception:
raise ImportError('can not find any torch_npu') raise ImportError('can not find any torch_npu')
# src # src
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment