"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "e2033d2dff6eb86b06f03e0405938a43a2608044"
Unverified commit 4fb0241b, authored by Minjie Wang, committed by GitHub
Browse files

[CUDA] Add CUDA11 support (#2308)



* add support for cuda 11

* fix inc bug in pytorch 1.8

* poke ci

* fix

* small fix

* try fix

* try fix
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent 6eda605c
...@@ -20,3 +20,6 @@ ...@@ -20,3 +20,6 @@
[submodule "third_party/phmap"] [submodule "third_party/phmap"]
path = third_party/phmap path = third_party/phmap
url = https://github.com/greg7mdp/parallel-hashmap.git url = https://github.com/greg7mdp/parallel-hashmap.git
[submodule "third_party/thrust"]
path = third_party/thrust
url = https://github.com/NVIDIA/thrust.git
...@@ -38,6 +38,17 @@ if(USE_CUDA) ...@@ -38,6 +38,17 @@ if(USE_CUDA)
message(STATUS "Build with CUDA support") message(STATUS "Build with CUDA support")
project(dgl C CXX) project(dgl C CXX)
include(cmake/modules/CUDA.cmake) include(cmake/modules/CUDA.cmake)
# Select the bundled CUB/Thrust headers for CUDA toolkits older than 11.1.
#   * CUDA < 11:   CUB is not part of the CUDA toolkit, so the external copy is required.
#   * CUDA == 11.0: the CUB shipped with the toolkit has a bug that causes an
#     "invalid device ordinal" error for DGL. It is fixed in 11.1 by
#     https://github.com/NVIDIA/cub/commit/9143e47e048641aa0e6ddfd645bcd54ff1059939
# The paths are anchored at PROJECT_SOURCE_DIR (set by the project(dgl ...) call
# above) rather than CMAKE_SOURCE_DIR, so they still resolve to DGL's own
# third_party directory when DGL is built as a subproject of another CMake tree.
if((CUDA_VERSION_MAJOR LESS 11) OR
   ((CUDA_VERSION_MAJOR EQUAL 11) AND (CUDA_VERSION_MINOR EQUAL 0)))
  message(STATUS "Detected CUDA of version ${CUDA_VERSION}. Use external CUB/Thrust library.")
  # Prepend so the external headers take precedence over any copies found in
  # the directories already listed; resulting order is cub first, then thrust.
  list(INSERT CUDA_INCLUDE_DIRS 0
    "${PROJECT_SOURCE_DIR}/third_party/cub"
    "${PROJECT_SOURCE_DIR}/third_party/thrust")
endif()
endif(USE_CUDA) endif(USE_CUDA)
# include directories # include directories
...@@ -47,7 +58,6 @@ include_directories("third_party/METIS/include/") ...@@ -47,7 +58,6 @@ include_directories("third_party/METIS/include/")
include_directories("third_party/dmlc-core/include") include_directories("third_party/dmlc-core/include")
include_directories("third_party/minigun/minigun") include_directories("third_party/minigun/minigun")
include_directories("third_party/minigun/third_party/moderngpu/src") include_directories("third_party/minigun/third_party/moderngpu/src")
include_directories("third_party/cub/")
include_directories("third_party/phmap/") include_directories("third_party/phmap/")
# initial variables # initial variables
...@@ -79,9 +89,9 @@ if(MSVC) ...@@ -79,9 +89,9 @@ if(MSVC)
endif() endif()
else(MSVC) else(MSVC)
include(CheckCXXCompilerFlag) include(CheckCXXCompilerFlag)
check_cxx_compiler_flag("-std=c++11" SUPPORT_CXX11) check_cxx_compiler_flag("-std=c++14" SUPPORT_CXX14)
set(CMAKE_C_FLAGS "-O2 -Wall -fPIC ${CMAKE_C_FLAGS}") set(CMAKE_C_FLAGS "-O2 -Wall -fPIC ${CMAKE_C_FLAGS}")
set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC -std=c++11 ${CMAKE_CXX_FLAGS}") set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC -std=c++14 ${CMAKE_CXX_FLAGS}")
endif(MSVC) endif(MSVC)
if(USE_OPENMP) if(USE_OPENMP)
......
...@@ -8,7 +8,7 @@ endif() ...@@ -8,7 +8,7 @@ endif()
###### Borrowed from MSHADOW project ###### Borrowed from MSHADOW project
include(CheckCXXCompilerFlag) include(CheckCXXCompilerFlag)
check_cxx_compiler_flag("-std=c++11" SUPPORT_CXX11) check_cxx_compiler_flag("-std=c++14" SUPPORT_CXX14)
set(dgl_known_gpu_archs "35 50 60 70") set(dgl_known_gpu_archs "35 50 60 70")
...@@ -176,7 +176,7 @@ macro(dgl_cuda_compile objlist_variable) ...@@ -176,7 +176,7 @@ macro(dgl_cuda_compile objlist_variable)
endforeach() endforeach()
if(UNIX OR APPLE) if(UNIX OR APPLE)
list(APPEND CUDA_NVCC_FLAGS -Xcompiler -fPIC) list(APPEND CUDA_NVCC_FLAGS -Xcompiler -fPIC --std=c++14)
endif() endif()
if(APPLE) if(APPLE)
...@@ -246,6 +246,8 @@ macro(dgl_config_cuda out_variable) ...@@ -246,6 +246,8 @@ macro(dgl_config_cuda out_variable)
set(NVCC_FLAGS_EXTRA "${NVCC_FLAGS_EXTRA} --expt-extended-lambda") set(NVCC_FLAGS_EXTRA "${NVCC_FLAGS_EXTRA} --expt-extended-lambda")
# suppress deprecated warning in moderngpu # suppress deprecated warning in moderngpu
set(NVCC_FLAGS_EXTRA "${NVCC_FLAGS_EXTRA} -Wno-deprecated-declarations") set(NVCC_FLAGS_EXTRA "${NVCC_FLAGS_EXTRA} -Wno-deprecated-declarations")
# for compile with c++14
set(NVCC_FLAGS_EXTRA "${NVCC_FLAGS_EXTRA} --expt-extended-lambda --std=c++14")
message(STATUS "NVCC extra flags: ${NVCC_FLAGS_EXTRA}") message(STATUS "NVCC extra flags: ${NVCC_FLAGS_EXTRA}")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} ${NVCC_FLAGS_EXTRA}") set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} ${NVCC_FLAGS_EXTRA}")
list(APPEND CMAKE_CUDA_FLAGS "${NVCC_FLAGS_EXTRA}") list(APPEND CMAKE_CUDA_FLAGS "${NVCC_FLAGS_EXTRA}")
......
...@@ -750,7 +750,7 @@ class HeteroGraphIndex(ObjectBase): ...@@ -750,7 +750,7 @@ class HeteroGraphIndex(ObjectBase):
n = self.number_of_nodes(dsttype) n = self.number_of_nodes(dsttype)
row = F.unsqueeze(dst, 0) row = F.unsqueeze(dst, 0)
col = F.unsqueeze(eid, 0) col = F.unsqueeze(eid, 0)
idx = F.cat([row, col], dim=0) idx = F.copy_to(F.cat([row, col], dim=0), ctx)
# FIXME(minjie): data type # FIXME(minjie): data type
dat = F.ones((m,), dtype=F.float32, ctx=ctx) dat = F.ones((m,), dtype=F.float32, ctx=ctx)
inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m)) inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m))
...@@ -758,7 +758,7 @@ class HeteroGraphIndex(ObjectBase): ...@@ -758,7 +758,7 @@ class HeteroGraphIndex(ObjectBase):
n = self.number_of_nodes(srctype) n = self.number_of_nodes(srctype)
row = F.unsqueeze(src, 0) row = F.unsqueeze(src, 0)
col = F.unsqueeze(eid, 0) col = F.unsqueeze(eid, 0)
idx = F.cat([row, col], dim=0) idx = F.copy_to(F.cat([row, col], dim=0), ctx)
# FIXME(minjie): data type # FIXME(minjie): data type
dat = F.ones((m,), dtype=F.float32, ctx=ctx) dat = F.ones((m,), dtype=F.float32, ctx=ctx)
inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m)) inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m))
...@@ -775,7 +775,7 @@ class HeteroGraphIndex(ObjectBase): ...@@ -775,7 +775,7 @@ class HeteroGraphIndex(ObjectBase):
# create index # create index
row = F.unsqueeze(F.cat([src, dst], dim=0), 0) row = F.unsqueeze(F.cat([src, dst], dim=0), 0)
col = F.unsqueeze(F.cat([eid, eid], dim=0), 0) col = F.unsqueeze(F.cat([eid, eid], dim=0), 0)
idx = F.cat([row, col], dim=0) idx = F.copy_to(F.cat([row, col], dim=0), ctx)
# FIXME(minjie): data type # FIXME(minjie): data type
x = -F.ones((n_entries,), dtype=F.float32, ctx=ctx) x = -F.ones((n_entries,), dtype=F.float32, ctx=ctx)
y = F.ones((n_entries,), dtype=F.float32, ctx=ctx) y = F.ones((n_entries,), dtype=F.float32, ctx=ctx)
......
Subproject commit c3cceac115c072fb63df1836ff46d8c60d9eb304 Subproject commit a3ee304a1f8e22f278df10600df2e4b333012592
Subproject commit 0ef5c509856e12cc408f0f00ed586b4c5b1a155c
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment