Commit f789a52f authored by YdrMaster's avatar YdrMaster
Browse files

issue/291/build: 为 cudnn 增加一个编译选项,并允许 cuda 使用标记弃用的函数


Signed-off-by: default avatarYdrMaster <ydrml@hotmail.com>
parent 14a53dee
......@@ -34,6 +34,7 @@ infiniStatus_t Handle::Internal::useCublas(cudaStream_t stream, const Fn<cublasH
return INFINI_STATUS_SUCCESS;
}
#ifdef ENABLE_CUDNN_API
infiniStatus_t Handle::Internal::useCudnn(cudaStream_t stream, const Fn<cudnnHandle_t> &f) const {
auto handle = dnn_handles.pop();
if (!handle) {
......@@ -44,6 +45,7 @@ infiniStatus_t Handle::Internal::useCudnn(cudaStream_t stream, const Fn<cudnnHan
dnn_handles.push(std::move(*handle));
return INFINI_STATUS_SUCCESS;
}
#endif
int Handle::Internal::warpSize() const { return _warp_size; }
int Handle::Internal::maxThreadsPerBlock() const { return _max_threads_per_block; }
......@@ -54,6 +56,7 @@ int Handle::Internal::gridSizeX() const { return _grid_size[0]; }
int Handle::Internal::gridSizeY() const { return _grid_size[1]; }
int Handle::Internal::gridSizeZ() const { return _grid_size[2]; }
#ifdef ENABLE_CUDNN_API
cudnnDataType_t getCudnnDtype(infiniDtype_t dt) {
switch (dt) {
case INFINI_DTYPE_F16:
......@@ -78,6 +81,7 @@ cudnnDataType_t getCudnnDtype(infiniDtype_t dt) {
return CUDNN_DATA_FLOAT;
}
}
#endif
namespace nvidia {
......
......@@ -6,7 +6,9 @@
namespace device::cuda {
#ifdef ENABLE_CUDNN_API
cudnnDataType_t getCudnnDtype(infiniDtype_t dt);
#endif
} // namespace device::cuda
......
......@@ -5,9 +5,12 @@
#include "../pool.h"
#include "cuda_handle.h"
#include <cublas_v2.h>
#include <cudnn.h>
#include <functional>
#ifdef ENABLE_CUDNN_API
#include <cudnn.h>
#endif
#define CHECK_CUBLAS(API) CHECK_INTERNAL(API, CUBLAS_STATUS_SUCCESS)
#define CHECK_CUDNN(API) CHECK_INTERNAL(API, CUDNN_STATUS_SUCCESS)
......@@ -15,7 +18,9 @@ namespace device::cuda {
class Handle::Internal {
Pool<cublasHandle_t> blas_handles;
#ifdef ENABLE_CUDNN_API
Pool<cudnnHandle_t> dnn_handles;
#endif
int _warp_size,
_max_threads_per_block,
......@@ -29,7 +34,9 @@ public:
Internal(int);
infiniStatus_t useCublas(cudaStream_t stream, const Fn<cublasHandle_t> &f) const;
#ifdef ENABLE_CUDNN_API
infiniStatus_t useCudnn(cudaStream_t stream, const Fn<cudnnHandle_t> &f) const;
#endif
int warpSize() const;
int maxThreadsPerBlock() const;
......
......@@ -52,17 +52,14 @@ if has_config("nv-gpu") then
includes("xmake/cuda.lua")
end
-- 天数智芯
option("iluvatar-gpu")
set_default(false)
option("cudnn")
set_default(true)
set_showmenu(true)
set_description("Whether to complie implementations for Iluvatar GPU")
set_description("Whether to complie cudnn for Nvidia GPU")
option_end()
if has_config("iluvatar-gpu") then
add_defines("ENABLE_CUDA_API")
add_defines("ENABLE_ILUVATAR_CUDA_API")
includes("xmake/iluvatar.lua")
if has_config("cudnn") then
add_defines("ENABLE_CUDNN_API")
end
-- 寒武纪
......@@ -89,6 +86,19 @@ if has_config("ascend-npu") then
includes("xmake/ascend.lua")
end
-- 天数智芯
option("iluvatar-gpu")
set_default(false)
set_showmenu(true)
set_description("Whether to complie implementations for Iluvatar GPU")
option_end()
if has_config("iluvatar-gpu") then
add_defines("ENABLE_CUDA_API")
add_defines("ENABLE_ILUVATAR_CUDA_API")
includes("xmake/iluvatar.lua")
end
-- 沐曦
option("metax-gpu")
set_default(false)
......
......@@ -15,8 +15,11 @@ target("infiniop-cuda")
set_policy("build.cuda.devlink", true)
set_toolchains("cuda")
add_links("cuda", "cublas", "cudnn")
add_links("cuda", "cublas")
add_linkdirs(CUDA_ROOT .. "/lib64/stubs")
if has_config("cudnn") then
add_links("cudnn")
end
add_cugencodes("native")
if is_plat("windows") then
......@@ -38,6 +41,8 @@ target("infiniop-cuda")
end
end
add_cuflags("-Xcompiler=-Wno-error=deprecated-declarations")
set_languages("cxx17")
add_files("../src/infiniop/devices/cuda/*.cu", "../src/infiniop/ops/*/cuda/*.cu", "../build/ninetoothed/*.c")
target_end()
......@@ -92,5 +97,5 @@ target("infiniccl-cuda")
end
end
set_languages("cxx17")
target_end()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment