Commit a929d1c6 authored by zhuwenwen's avatar zhuwenwen
Browse files

support unittests

parent d42788f0
...@@ -26,7 +26,7 @@ Gemm::Gemm(IAllocator* allocator, cudaStream_t stream, std::string config_file) ...@@ -26,7 +26,7 @@ Gemm::Gemm(IAllocator* allocator, cudaStream_t stream, std::string config_file)
stream_ = stream; stream_ = stream;
mutex_ = new std::mutex(); // mutex per process mutex_ = new std::mutex(); // mutex per process
check_cuda_error(cublasCreate(&cublas_handle_)); check_cuda_error(cublasCreate(&cublas_handle_));
check_cuda_error(cublasLtCreate(&cublaslt_handle_)); // check_cuda_error(cublasLtCreate(&cublaslt_handle_));
check_cuda_error(cublasSetStream(cublas_handle_, stream)); check_cuda_error(cublasSetStream(cublas_handle_, stream));
if (allocator_ != nullptr) { if (allocator_ != nullptr) {
...@@ -41,7 +41,7 @@ Gemm::~Gemm() ...@@ -41,7 +41,7 @@ Gemm::~Gemm()
allocator_->free((void**)(&workspace_)); allocator_->free((void**)(&workspace_));
allocator_ = nullptr; allocator_ = nullptr;
} }
cublasLtDestroy(cublaslt_handle_); // cublasLtDestroy(cublaslt_handle_);
cublasDestroy(cublas_handle_); cublasDestroy(cublas_handle_);
delete cublas_algo_map_; delete cublas_algo_map_;
delete mutex_; delete mutex_;
...@@ -248,7 +248,8 @@ void Gemm::gemm(const GemmOp transa, ...@@ -248,7 +248,8 @@ void Gemm::gemm(const GemmOp transa,
mutex_->lock(); mutex_->lock();
// Use cublas as default in FP32 and cublasLt as default in FP16 // Use cublas as default in FP32 and cublasLt as default in FP16
bool is_fp16_compute_type = compute_type_ == TYPE_FP16; bool is_fp16_compute_type = compute_type_ == TYPE_FP16;
bool using_cublasLt = Atype == TYPE_FP16; // bool using_cublasLt = Atype == TYPE_FP16;
bool using_cublasLt = (Atype == TYPE_FP16) ? false : false;
int batch_count = 1; int batch_count = 1;
half h_alpha = (half)alpha; half h_alpha = (half)alpha;
......
...@@ -19,6 +19,10 @@ FetchContent_Declare( ...@@ -19,6 +19,10 @@ FetchContent_Declare(
googletest googletest
GIT_REPOSITORY https://github.com/google/googletest.git GIT_REPOSITORY https://github.com/google/googletest.git
GIT_TAG release-1.12.1 GIT_TAG release-1.12.1
# URL /path/to/local/googletest-release-1.12.1.zip
# URL_HASH SHA256=24564e3b712d3eb30ac9a85d92f7d720f60cc0173730ac166f27dda7fed76cb2
# export C_INCLUDE_PATH=/path/to/local/googletest-release-1.12.1/googletest/include${C_INCLUDE_PATH:+:${C_INCLUDE_PATH}}
# export CPLUS_INCLUDE_PATH=/path/to/local/googletest-release-1.12.1/include${CPLUS_INCLUDE_PATH:+:${CPLUS_INCLUDE_PATH}}
) )
add_definitions(-DTORCH_CUDA=1) add_definitions(-DTORCH_CUDA=1)
......
...@@ -74,7 +74,7 @@ bool test_context_sharing(const std::string& weight_dir, const std::string& data ...@@ -74,7 +74,7 @@ bool test_context_sharing(const std::string& weight_dir, const std::string& data
cublasLtHandle_t cublaslt_handle; cublasLtHandle_t cublaslt_handle;
check_cuda_error(cudaStreamCreate(&stream)); check_cuda_error(cudaStreamCreate(&stream));
check_cuda_error(cublasCreate(&cublas_handle)); check_cuda_error(cublasCreate(&cublas_handle));
check_cuda_error(cublasLtCreate(&cublaslt_handle)); // check_cuda_error(cublasLtCreate(&cublaslt_handle));
check_cuda_error(cublasSetStream(cublas_handle, stream)); check_cuda_error(cublasSetStream(cublas_handle, stream));
cublasAlgoMap cublas_algo_map(GEMM_CONFIG); cublasAlgoMap cublas_algo_map(GEMM_CONFIG);
......
...@@ -378,7 +378,7 @@ void testGemmConsistencyMatmul(size_t m, size_t n, size_t k) { ...@@ -378,7 +378,7 @@ void testGemmConsistencyMatmul(size_t m, size_t n, size_t k) {
cublasHandle_t cublas_handle; cublasHandle_t cublas_handle;
cublasLtHandle_t cublaslt_handle; cublasLtHandle_t cublaslt_handle;
check_cuda_error(cublasCreate(&cublas_handle)); check_cuda_error(cublasCreate(&cublas_handle));
check_cuda_error(cublasLtCreate(&cublaslt_handle)); // check_cuda_error(cublasLtCreate(&cublaslt_handle));
check_cuda_error(cublasSetStream(cublas_handle, stream)); check_cuda_error(cublasSetStream(cublas_handle, stream));
cublasAlgoMap cublas_algo_map(GEMM_CONFIG); cublasAlgoMap cublas_algo_map(GEMM_CONFIG);
std::mutex* cublas_wrapper_mutex = new std::mutex(); std::mutex* cublas_wrapper_mutex = new std::mutex();
...@@ -436,7 +436,7 @@ void testGemmConsistencyMatmul(size_t m, size_t n, size_t k) { ...@@ -436,7 +436,7 @@ void testGemmConsistencyMatmul(size_t m, size_t n, size_t k) {
} }
delete cublas_wrapper_mutex; delete cublas_wrapper_mutex;
check_cuda_error(cublasLtDestroy(cublaslt_handle)); // check_cuda_error(cublasLtDestroy(cublaslt_handle));
check_cuda_error(cublasDestroy(cublas_handle)); check_cuda_error(cublasDestroy(cublas_handle));
check_cuda_error(cudaStreamDestroy(stream)); check_cuda_error(cudaStreamDestroy(stream));
} }
...@@ -494,7 +494,7 @@ void testGemmConsistencyBatchedMatmul(size_t m, size_t n, size_t k) { ...@@ -494,7 +494,7 @@ void testGemmConsistencyBatchedMatmul(size_t m, size_t n, size_t k) {
cublasHandle_t cublas_handle; cublasHandle_t cublas_handle;
cublasLtHandle_t cublaslt_handle; cublasLtHandle_t cublaslt_handle;
check_cuda_error(cublasCreate(&cublas_handle)); check_cuda_error(cublasCreate(&cublas_handle));
check_cuda_error(cublasLtCreate(&cublaslt_handle)); // check_cuda_error(cublasLtCreate(&cublaslt_handle));
check_cuda_error(cublasSetStream(cublas_handle, stream)); check_cuda_error(cublasSetStream(cublas_handle, stream));
cublasAlgoMap cublas_algo_map(GEMM_CONFIG); cublasAlgoMap cublas_algo_map(GEMM_CONFIG);
std::mutex* cublas_wrapper_mutex = new std::mutex(); std::mutex* cublas_wrapper_mutex = new std::mutex();
...@@ -569,7 +569,7 @@ void testGemmConsistencyBatchedMatmul(size_t m, size_t n, size_t k) { ...@@ -569,7 +569,7 @@ void testGemmConsistencyBatchedMatmul(size_t m, size_t n, size_t k) {
c_tensors.clear(); c_tensors.clear();
expecteds.clear(); expecteds.clear();
delete cublas_wrapper_mutex; delete cublas_wrapper_mutex;
check_cuda_error(cublasLtDestroy(cublaslt_handle)); // check_cuda_error(cublasLtDestroy(cublaslt_handle));
check_cuda_error(cublasDestroy(cublas_handle)); check_cuda_error(cublasDestroy(cublas_handle));
check_cuda_error(cudaStreamDestroy(stream)); check_cuda_error(cudaStreamDestroy(stream));
} }
...@@ -594,7 +594,7 @@ void testGemmConsistencyStridedBatchedMatmul(size_t batch_size, size_t m, size_t ...@@ -594,7 +594,7 @@ void testGemmConsistencyStridedBatchedMatmul(size_t batch_size, size_t m, size_t
cublasHandle_t cublas_handle; cublasHandle_t cublas_handle;
cublasLtHandle_t cublaslt_handle; cublasLtHandle_t cublaslt_handle;
check_cuda_error(cublasCreate(&cublas_handle)); check_cuda_error(cublasCreate(&cublas_handle));
check_cuda_error(cublasLtCreate(&cublaslt_handle)); // check_cuda_error(cublasLtCreate(&cublaslt_handle));
check_cuda_error(cublasSetStream(cublas_handle, stream)); check_cuda_error(cublasSetStream(cublas_handle, stream));
cublasAlgoMap cublas_algo_map(GEMM_CONFIG); cublasAlgoMap cublas_algo_map(GEMM_CONFIG);
std::mutex* cublas_wrapper_mutex = new std::mutex(); std::mutex* cublas_wrapper_mutex = new std::mutex();
...@@ -683,7 +683,7 @@ void testGemmConsistencyStridedBatchedMatmul(size_t batch_size, size_t m, size_t ...@@ -683,7 +683,7 @@ void testGemmConsistencyStridedBatchedMatmul(size_t batch_size, size_t m, size_t
} }
delete cublas_wrapper_mutex; delete cublas_wrapper_mutex;
check_cuda_error(cublasLtDestroy(cublaslt_handle)); // check_cuda_error(cublasLtDestroy(cublaslt_handle));
check_cuda_error(cublasDestroy(cublas_handle)); check_cuda_error(cublasDestroy(cublas_handle));
check_cuda_error(cudaStreamDestroy(stream)); check_cuda_error(cudaStreamDestroy(stream));
} }
...@@ -779,7 +779,7 @@ void testSpGemmConsistencyMatmul(size_t m, size_t n, size_t k) { ...@@ -779,7 +779,7 @@ void testSpGemmConsistencyMatmul(size_t m, size_t n, size_t k) {
cublasHandle_t cublas_handle; cublasHandle_t cublas_handle;
cublasLtHandle_t cublaslt_handle; cublasLtHandle_t cublaslt_handle;
check_cuda_error(cublasCreate(&cublas_handle)); check_cuda_error(cublasCreate(&cublas_handle));
check_cuda_error(cublasLtCreate(&cublaslt_handle)); // check_cuda_error(cublasLtCreate(&cublaslt_handle));
check_cuda_error(cublasSetStream(cublas_handle, stream)); check_cuda_error(cublasSetStream(cublas_handle, stream));
cublasAlgoMap cublas_algo_map(GEMM_CONFIG); cublasAlgoMap cublas_algo_map(GEMM_CONFIG);
std::mutex* cublas_wrapper_mutex = new std::mutex(); std::mutex* cublas_wrapper_mutex = new std::mutex();
...@@ -844,7 +844,7 @@ void testSpGemmConsistencyMatmul(size_t m, size_t n, size_t k) { ...@@ -844,7 +844,7 @@ void testSpGemmConsistencyMatmul(size_t m, size_t n, size_t k) {
} }
delete cublas_wrapper_mutex; delete cublas_wrapper_mutex;
check_cuda_error(cublasLtDestroy(cublaslt_handle)); // check_cuda_error(cublasLtDestroy(cublaslt_handle));
check_cuda_error(cublasDestroy(cublas_handle)); check_cuda_error(cublasDestroy(cublas_handle));
check_cuda_error(cudaStreamDestroy(stream)); check_cuda_error(cudaStreamDestroy(stream));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment