Unverified Commit b361a81a authored by sherie's avatar sherie Committed by GitHub
Browse files

[Fix] Fix setup.py for torch_npu v2.1 (#2928)

parent 72084061
......@@ -27,9 +27,16 @@
#define NPU_NAME_SPACE at_npu::native
#if MMCV_WITH_XLA
#define REGISTER_NPU_IMPL(key, value) REGISTER_DEVICE_IMPL(key, XLA, value)
#else
#define REGISTER_NPU_IMPL(key, value) \
REGISTER_DEVICE_IMPL(key, PrivateUse1, value)
#endif
#define CHECK_NPU(x) \
TORCH_CHECK(x.device().type() == at::kXLA, #x " must be a NPU tensor")
TORCH_CHECK( \
x.device().type() == at::kXLA || x.device().type() == at::kPrivateUse1, \
#x " must be a NPU tensor")
#endif // PYTORCH_NPU_HELPER_HPP_
......@@ -396,6 +396,8 @@ def get_extensions():
from torch_npu.utils.cpp_extension import NpuExtension
define_macros += [('MMCV_WITH_NPU', None)]
extension = NpuExtension
if parse_version(torch.__version__) >= parse_version('2.0.0'):
define_macros += [('MMCV_WITH_XLA', None)]
except Exception:
raise ImportError('can not find any torch_npu')
# src
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment