Unverified Commit e2803b16 authored by cyanguwa's avatar cyanguwa Committed by GitHub
Browse files

Update FindCUDNN.cmake for cuDNN 9 (#640)



* update cudnn cmake for v9
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add back license information
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent 70bd26e8
...@@ -8,25 +8,29 @@ find_path( ...@@ -8,25 +8,29 @@ find_path(
CUDNN_INCLUDE_DIR cudnn.h CUDNN_INCLUDE_DIR cudnn.h
HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${CUDAToolkit_INCLUDE_DIRS} HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${CUDAToolkit_INCLUDE_DIRS}
PATH_SUFFIXES include PATH_SUFFIXES include
REQUIRED
) )
function(find_cudnn_library NAME) file(READ "${CUDNN_INCLUDE_DIR}/cudnn_version.h" cudnn_version_header)
string(TOUPPER ${NAME} UPPERCASE_NAME) string(REGEX MATCH "#define CUDNN_MAJOR [1-9]+" macrodef "${cudnn_version_header}")
string(REGEX MATCH "[1-9]+" CUDNN_MAJOR_VERSION "${macrodef}")
function(find_cudnn_library NAME)
find_library( find_library(
${UPPERCASE_NAME}_LIBRARY ${NAME} ${NAME}_LIBRARY ${NAME} "lib${NAME}.so.${CUDNN_MAJOR_VERSION}"
HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${CUDAToolkit_LIBRARY_DIR} HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${CUDAToolkit_LIBRARY_DIR}
PATH_SUFFIXES lib64 lib/x64 lib PATH_SUFFIXES lib64 lib/x64 lib
REQUIRED
) )
if(${UPPERCASE_NAME}_LIBRARY) if(${NAME}_LIBRARY)
add_library(CUDNN::${NAME} UNKNOWN IMPORTED) add_library(CUDNN::${NAME} UNKNOWN IMPORTED)
set_target_properties( set_target_properties(
CUDNN::${NAME} PROPERTIES CUDNN::${NAME} PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES ${CUDNN_INCLUDE_DIR} INTERFACE_INCLUDE_DIRECTORIES ${CUDNN_INCLUDE_DIR}
IMPORTED_LOCATION ${${UPPERCASE_NAME}_LIBRARY} IMPORTED_LOCATION ${${NAME}_LIBRARY}
) )
message(STATUS "${NAME} found at ${${UPPERCASE_NAME}_LIBRARY}.") message(STATUS "${NAME} found at ${${NAME}_LIBRARY}.")
else() else()
message(STATUS "${NAME} not found.") message(STATUS "${NAME} not found.")
endif() endif()
...@@ -35,22 +39,16 @@ function(find_cudnn_library NAME) ...@@ -35,22 +39,16 @@ function(find_cudnn_library NAME)
endfunction() endfunction()
find_cudnn_library(cudnn) find_cudnn_library(cudnn)
find_cudnn_library(cudnn_adv_infer)
find_cudnn_library(cudnn_adv_train)
find_cudnn_library(cudnn_cnn_infer)
find_cudnn_library(cudnn_cnn_train)
find_cudnn_library(cudnn_ops_infer)
find_cudnn_library(cudnn_ops_train)
include (FindPackageHandleStandardArgs) include (FindPackageHandleStandardArgs)
find_package_handle_standard_args( find_package_handle_standard_args(
CUDNN REQUIRED_VARS LIBRARY REQUIRED_VARS
CUDNN_INCLUDE_DIR CUDNN_LIBRARY CUDNN_INCLUDE_DIR cudnn_LIBRARY
) )
if(CUDNN_INCLUDE_DIR AND CUDNN_LIBRARY) if(CUDNN_INCLUDE_DIR AND cudnn_LIBRARY)
message(STATUS "cuDNN: ${CUDNN_LIBRARY}") message(STATUS "cuDNN: ${cudnn_LIBRARY}")
message(STATUS "cuDNN: ${CUDNN_INCLUDE_DIR}") message(STATUS "cuDNN: ${CUDNN_INCLUDE_DIR}")
set(CUDNN_FOUND ON CACHE INTERNAL "cuDNN Library Found") set(CUDNN_FOUND ON CACHE INTERNAL "cuDNN Library Found")
...@@ -69,6 +67,20 @@ target_include_directories( ...@@ -69,6 +67,20 @@ target_include_directories(
) )
target_link_libraries( target_link_libraries(
CUDNN::cudnn_all
INTERFACE
CUDNN::cudnn
)
if(CUDNN_MAJOR_VERSION EQUAL 8)
find_cudnn_library(cudnn_adv_infer)
find_cudnn_library(cudnn_adv_train)
find_cudnn_library(cudnn_cnn_infer)
find_cudnn_library(cudnn_cnn_train)
find_cudnn_library(cudnn_ops_infer)
find_cudnn_library(cudnn_ops_train)
target_link_libraries(
CUDNN::cudnn_all CUDNN::cudnn_all
INTERFACE INTERFACE
CUDNN::cudnn_adv_train CUDNN::cudnn_adv_train
...@@ -77,5 +89,25 @@ target_link_libraries( ...@@ -77,5 +89,25 @@ target_link_libraries(
CUDNN::cudnn_adv_infer CUDNN::cudnn_adv_infer
CUDNN::cudnn_cnn_infer CUDNN::cudnn_cnn_infer
CUDNN::cudnn_ops_infer CUDNN::cudnn_ops_infer
CUDNN::cudnn )
) elseif(CUDNN_MAJOR_VERSION EQUAL 9)
find_cudnn_library(cudnn_cnn)
find_cudnn_library(cudnn_adv)
find_cudnn_library(cudnn_graph)
find_cudnn_library(cudnn_ops)
find_cudnn_library(cudnn_engines_runtime_compiled)
find_cudnn_library(cudnn_engines_precompiled)
find_cudnn_library(cudnn_heuristic)
target_link_libraries(
CUDNN::cudnn_all
INTERFACE
CUDNN::cudnn_adv
CUDNN::cudnn_ops
CUDNN::cudnn_cnn
CUDNN::cudnn_graph
CUDNN::cudnn_engines_runtime_compiled
CUDNN::cudnn_engines_precompiled
CUDNN::cudnn_heuristic
)
endif()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment