Unverified Commit 50e7a3da authored by cyanguwa's avatar cyanguwa Committed by GitHub
Browse files

Replace FindCUDNN.cmake with cudnn-frontend's cuDNN.cmake (#831)



* use 3rdparty cudnn-frontend cmake to find cuDNN
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add check for 3rdparty/cudnn-frontend module
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* switch order of CUDA and cuDNN
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent 90c267f2
...@@ -19,9 +19,19 @@ if (CMAKE_BUILD_TYPE STREQUAL "Debug") ...@@ -19,9 +19,19 @@ if (CMAKE_BUILD_TYPE STREQUAL "Debug")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -G") set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -G")
endif() endif()
list(PREPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
find_package(CUDAToolkit REQUIRED cublas nvToolsExt) find_package(CUDAToolkit REQUIRED cublas nvToolsExt)
find_package(CUDNN REQUIRED cudnn)
# Check for cuDNN frontend API
set(CUDNN_FRONTEND_INCLUDE_DIR
"${CMAKE_SOURCE_DIR}/../3rdparty/cudnn-frontend/include")
if(NOT EXISTS "${CUDNN_FRONTEND_INCLUDE_DIR}")
message(FATAL_ERROR
"Could not find cuDNN frontend API. "
"Try running 'git submodule update --init --recursive' "
"within the Transformer Engine source.")
endif()
include(${CMAKE_SOURCE_DIR}/../3rdparty/cudnn-frontend/cmake/cuDNN.cmake)
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
include_directories(${PROJECT_SOURCE_DIR}) include_directories(${PROJECT_SOURCE_DIR})
......
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
add_library(CUDNN::cudnn_all INTERFACE IMPORTED)
find_path(
CUDNN_INCLUDE_DIR cudnn.h
HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${CUDAToolkit_INCLUDE_DIRS}
PATH_SUFFIXES include
REQUIRED
)
file(READ "${CUDNN_INCLUDE_DIR}/cudnn_version.h" cudnn_version_header)
string(REGEX MATCH "#define CUDNN_MAJOR [1-9]+" macrodef "${cudnn_version_header}")
string(REGEX MATCH "[1-9]+" CUDNN_MAJOR_VERSION "${macrodef}")
function(find_cudnn_library NAME)
find_library(
${NAME}_LIBRARY ${NAME} "lib${NAME}.so.${CUDNN_MAJOR_VERSION}"
HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${CUDAToolkit_LIBRARY_DIR}
PATH_SUFFIXES lib64 lib/x64 lib
REQUIRED
)
if(${NAME}_LIBRARY)
add_library(CUDNN::${NAME} UNKNOWN IMPORTED)
set_target_properties(
CUDNN::${NAME} PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES ${CUDNN_INCLUDE_DIR}
IMPORTED_LOCATION ${${NAME}_LIBRARY}
)
message(STATUS "${NAME} found at ${${NAME}_LIBRARY}.")
else()
message(STATUS "${NAME} not found.")
endif()
endfunction()
find_cudnn_library(cudnn)
include (FindPackageHandleStandardArgs)
find_package_handle_standard_args(
LIBRARY REQUIRED_VARS
CUDNN_INCLUDE_DIR cudnn_LIBRARY
)
if(CUDNN_INCLUDE_DIR AND cudnn_LIBRARY)
message(STATUS "cuDNN: ${cudnn_LIBRARY}")
message(STATUS "cuDNN: ${CUDNN_INCLUDE_DIR}")
set(CUDNN_FOUND ON CACHE INTERNAL "cuDNN Library Found")
else()
set(CUDNN_FOUND OFF CACHE INTERNAL "cuDNN Library Not Found")
endif()
target_include_directories(
CUDNN::cudnn_all
INTERFACE
$<INSTALL_INTERFACE:include>
$<BUILD_INTERFACE:${CUDNN_INCLUDE_DIR}>
)
target_link_libraries(
CUDNN::cudnn_all
INTERFACE
CUDNN::cudnn
)
if(CUDNN_MAJOR_VERSION EQUAL 8)
find_cudnn_library(cudnn_adv_infer)
find_cudnn_library(cudnn_adv_train)
find_cudnn_library(cudnn_cnn_infer)
find_cudnn_library(cudnn_cnn_train)
find_cudnn_library(cudnn_ops_infer)
find_cudnn_library(cudnn_ops_train)
target_link_libraries(
CUDNN::cudnn_all
INTERFACE
CUDNN::cudnn_adv_train
CUDNN::cudnn_ops_train
CUDNN::cudnn_cnn_train
CUDNN::cudnn_adv_infer
CUDNN::cudnn_cnn_infer
CUDNN::cudnn_ops_infer
)
elseif(CUDNN_MAJOR_VERSION EQUAL 9)
find_cudnn_library(cudnn_cnn)
find_cudnn_library(cudnn_adv)
find_cudnn_library(cudnn_graph)
find_cudnn_library(cudnn_ops)
find_cudnn_library(cudnn_engines_runtime_compiled)
find_cudnn_library(cudnn_engines_precompiled)
find_cudnn_library(cudnn_heuristic)
target_link_libraries(
CUDNN::cudnn_all
INTERFACE
CUDNN::cudnn_adv
CUDNN::cudnn_ops
CUDNN::cudnn_cnn
CUDNN::cudnn_graph
CUDNN::cudnn_engines_runtime_compiled
CUDNN::cudnn_engines_precompiled
CUDNN::cudnn_heuristic
)
endif()
...@@ -40,16 +40,6 @@ add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) ...@@ -40,16 +40,6 @@ add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
target_include_directories(transformer_engine PUBLIC target_include_directories(transformer_engine PUBLIC
"${CMAKE_CURRENT_SOURCE_DIR}/include") "${CMAKE_CURRENT_SOURCE_DIR}/include")
# Check for cuDNN frontend API
set(CUDNN_FRONTEND_INCLUDE_DIR
"${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/include")
if(NOT EXISTS "${CUDNN_FRONTEND_INCLUDE_DIR}")
message(FATAL_ERROR
"Could not find cuDNN frontend API. "
"Try running 'git submodule update --init --recursive' "
"within the Transformer Engine source.")
endif()
# Configure dependencies # Configure dependencies
target_link_libraries(transformer_engine PUBLIC target_link_libraries(transformer_engine PUBLIC
CUDA::cublas CUDA::cublas
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment