Unverified Commit 0cb5f0fd authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Sparse] Be compatible with Pytorch 1.13 and later version (#4935)

parent c53deb26
......@@ -36,6 +36,8 @@ dgl_option(USE_HDFS "Build with HDFS support" OFF) # Set env HADOOP_HDFS_HOME if
dgl_option(REBUILD_LIBXSMM "Clean LIBXSMM build cache at every build" OFF) # Set env HADOOP_HDFS_HOME if needed
dgl_option(USE_EPOLL "Build with epoll for socket communicator" ON)
dgl_option(TP_BUILD_LIBUV "Build libuv together with tensorpipe (only impacts Linux)" ON)
dgl_option(BUILD_SPARSE "Build DGL sparse library" OFF)
dgl_option(TORCH_PYTHON_INTERPS "Python interpreter used to build tensoradapter and DGL sparse library" python3)
# Set debug compile option for gdb, only happens when -DCMAKE_BUILD_TYPE=DEBUG
if (NOT MSVC)
......
......@@ -13,6 +13,9 @@ message(STATUS "find_cmake.py output: ${TORCH_PREFIX_VER}")
list(GET TORCH_PREFIX_VER 0 TORCH_PREFIX)
list(GET TORCH_PREFIX_VER 1 TORCH_VER)
message(STATUS "Configuring for PyTorch ${TORCH_VER}")
string(REPLACE "." ";" TORCH_VERSION_LIST ${TORCH_VER})
list(GET TORCH_VERSION_LIST 0 TORCH_VERSION_MAJOR)
list(GET TORCH_VERSION_LIST 1 TORCH_VERSION_MINOR)
if(USE_CUDA)
add_definitions(-DDGL_USE_CUDA)
......@@ -34,6 +37,8 @@ add_library(${LIB_DGL_SPARSE_NAME} SHARED ${SPARSE_SRC} ${SPARSE_HEADERS})
target_include_directories(
${LIB_DGL_SPARSE_NAME} PRIVATE ${SPARSE_DIR} ${SPARSE_HEADERS})
target_link_libraries(${LIB_DGL_SPARSE_NAME} "${TORCH_LIBRARIES}")
target_compile_definitions(${LIB_DGL_SPARSE_NAME} PRIVATE TORCH_VERSION_MAJOR=${TORCH_VERSION_MAJOR})
target_compile_definitions(${LIB_DGL_SPARSE_NAME} PRIVATE TORCH_VERSION_MINOR=${TORCH_VERSION_MINOR})
target_include_directories(${LIB_DGL_SPARSE_NAME} PRIVATE "${CMAKE_SOURCE_DIR}/third_party/dmlc-core/include")
target_link_libraries(${LIB_DGL_SPARSE_NAME} dmlc)
......
......@@ -39,4 +39,14 @@
#undef DLOG
#undef LOG_IF
// For Pytorch version later than 1.12, redefine CHECK_* to TORCH_CHECK_*.
#if !(TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR <= 12)
#define CHECK_EQ(val1, val2) TORCH_CHECK_EQ(val1, val2)
#define CHECK_NE(val1, val2) TORCH_CHECK_NE(val1, val2)
#define CHECK_LE(val1, val2) TORCH_CHECK_LE(val1, val2)
#define CHECK_LT(val1, val2) TORCH_CHECK_LT(val1, val2)
#define CHECK_GE(val1, val2) TORCH_CHECK_GE(val1, val2)
#define CHECK_GT(val1, val2) TORCH_CHECK_GT(val1, val2)
#endif
#endif // SPARSE_DGL_HEADERS_H_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment