Commit a929d1c6 authored by zhuwenwen

support unittests

parent d42788f0
......@@ -26,7 +26,7 @@ Gemm::Gemm(IAllocator* allocator, cudaStream_t stream, std::string config_file)
stream_ = stream;
mutex_ = new std::mutex(); // mutex per process
check_cuda_error(cublasCreate(&cublas_handle_));
check_cuda_error(cublasLtCreate(&cublaslt_handle_));
// check_cuda_error(cublasLtCreate(&cublaslt_handle_));
check_cuda_error(cublasSetStream(cublas_handle_, stream));
if (allocator_ != nullptr) {
......@@ -41,7 +41,7 @@ Gemm::~Gemm()
allocator_->free((void**)(&workspace_));
allocator_ = nullptr;
}
cublasLtDestroy(cublaslt_handle_);
// cublasLtDestroy(cublaslt_handle_);
cublasDestroy(cublas_handle_);
delete cublas_algo_map_;
delete mutex_;
......@@ -248,7 +248,8 @@ void Gemm::gemm(const GemmOp transa,
mutex_->lock();
// Use cublas as default in FP32 and cublasLt as default in FP16
bool is_fp16_compute_type = compute_type_ == TYPE_FP16;
bool using_cublasLt = Atype == TYPE_FP16;
// bool using_cublasLt = Atype == TYPE_FP16;
bool using_cublasLt = false;  // cublasLt disabled; always take the cublas path
int batch_count = 1;
half h_alpha = (half)alpha;
......
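With `using_cublasLt` forced to `false`, every GEMM in this path now goes through plain cuBLAS rather than cublasLt. Below is a minimal sketch of what that fallback can look like for an FP16 matmul; the `check_cuda_error` macro and handle type mirror the diff, but the function itself is illustrative and not the repository's `Gemm::gemm` implementation (which also supports an FP16 compute type with `half` scalars).

```cpp
// Sketch only: FP16 GEMM routed through cublasGemmEx instead of cublasLt.
// Assumes column-major m x k (A), k x n (B), m x n (C) device buffers.
#include <cublas_v2.h>
#include <cuda_fp16.h>

void gemm_fp16_cublas_only(cublasHandle_t handle,
                           int m, int n, int k,
                           const half* A, const half* B, half* C)
{
    const float alpha = 1.0f;  // FP32 scalars paired with an FP32 compute type
    const float beta  = 0.0f;
    check_cuda_error(cublasGemmEx(handle,
                                  CUBLAS_OP_N, CUBLAS_OP_N,
                                  m, n, k,
                                  &alpha,
                                  A, CUDA_R_16F, m,
                                  B, CUDA_R_16F, k,
                                  &beta,
                                  C, CUDA_R_16F, m,
                                  CUBLAS_COMPUTE_32F,
                                  CUBLAS_GEMM_DEFAULT));
}
```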
......@@ -19,6 +19,10 @@ FetchContent_Declare(
googletest
GIT_REPOSITORY https://github.com/google/googletest.git
GIT_TAG release-1.12.1
# URL /path/to/local/googletest-release-1.12.1.zip
# URL_HASH SHA256=24564e3b712d3eb30ac9a85d92f7d720f60cc0173730ac166f27dda7fed76cb2
# export C_INCLUDE_PATH=/path/to/local/googletest-release-1.12.1/googletest/include${C_INCLUDE_PATH:+:${C_INCLUDE_PATH}}
# export CPLUS_INCLUDE_PATH=/path/to/local/googletest-release-1.12.1/include${CPLUS_INCLUDE_PATH:+:${CPLUS_INCLUDE_PATH}}
)
add_definitions(-DTORCH_CUDA=1)
......
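The googletest dependency declared above (with a commented-out local-archive fallback and include-path exports) is what the new unit tests build against. A minimal, self-contained sketch of such a test is shown below; the suite and test names and the tolerance are hypothetical placeholders, not taken from the repository.

```cpp
// Sketch of a googletest case in the style of the new unit tests.
// Suite/test names and the tolerance are placeholders (hypothetical).
#include <gtest/gtest.h>

TEST(GemmConsistency, SmallMatmul)
{
    // In the real tests the expected and actual values come from the
    // reference and GPU GEMM paths; fixed numbers keep this sketch
    // self-contained.
    float expected = 2.0f;
    float actual   = 2.0004f;
    EXPECT_NEAR(expected, actual, 1e-3f);
}

int main(int argc, char** argv)
{
    ::testing::InitGoogleTest(&argc, argv);
    return RUN_ALL_TESTS();
}
```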
......@@ -74,7 +74,7 @@ bool test_context_sharing(const std::string& weight_dir, const std::string& data
cublasLtHandle_t cublaslt_handle;
check_cuda_error(cudaStreamCreate(&stream));
check_cuda_error(cublasCreate(&cublas_handle));
check_cuda_error(cublasLtCreate(&cublaslt_handle));
// check_cuda_error(cublasLtCreate(&cublaslt_handle));
check_cuda_error(cublasSetStream(cublas_handle, stream));
cublasAlgoMap cublas_algo_map(GEMM_CONFIG);
......
......@@ -378,7 +378,7 @@ void testGemmConsistencyMatmul(size_t m, size_t n, size_t k) {
cublasHandle_t cublas_handle;
cublasLtHandle_t cublaslt_handle;
check_cuda_error(cublasCreate(&cublas_handle));
check_cuda_error(cublasLtCreate(&cublaslt_handle));
// check_cuda_error(cublasLtCreate(&cublaslt_handle));
check_cuda_error(cublasSetStream(cublas_handle, stream));
cublasAlgoMap cublas_algo_map(GEMM_CONFIG);
std::mutex* cublas_wrapper_mutex = new std::mutex();
......@@ -436,7 +436,7 @@ void testGemmConsistencyMatmul(size_t m, size_t n, size_t k) {
}
delete cublas_wrapper_mutex;
check_cuda_error(cublasLtDestroy(cublaslt_handle));
// check_cuda_error(cublasLtDestroy(cublaslt_handle));
check_cuda_error(cublasDestroy(cublas_handle));
check_cuda_error(cudaStreamDestroy(stream));
}
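Each `testGemmConsistency*` function compares a GPU result against an expected tensor before tearing the handles down. A hedged sketch of such a comparison helper follows; the function name, tolerances, and mixed absolute/relative check are assumptions for illustration, not code from the test file.

```cpp
// Sketch: copy an FP16 result back to the host and compare it against an
// FP32 reference with absolute + relative tolerances (values illustrative).
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cmath>
#include <vector>

bool all_close(const half* d_result, const std::vector<float>& expected,
               float rtol = 1e-2f, float atol = 1e-3f)
{
    std::vector<half> h_result(expected.size());
    cudaMemcpy(h_result.data(), d_result, expected.size() * sizeof(half),
               cudaMemcpyDeviceToHost);
    for (size_t i = 0; i < expected.size(); ++i) {
        float got = __half2float(h_result[i]);
        if (std::fabs(got - expected[i]) >
            atol + rtol * std::fabs(expected[i])) {
            return false;
        }
    }
    return true;
}
```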
......@@ -494,7 +494,7 @@ void testGemmConsistencyBatchedMatmul(size_t m, size_t n, size_t k) {
cublasHandle_t cublas_handle;
cublasLtHandle_t cublaslt_handle;
check_cuda_error(cublasCreate(&cublas_handle));
check_cuda_error(cublasLtCreate(&cublaslt_handle));
// check_cuda_error(cublasLtCreate(&cublaslt_handle));
check_cuda_error(cublasSetStream(cublas_handle, stream));
cublasAlgoMap cublas_algo_map(GEMM_CONFIG);
std::mutex* cublas_wrapper_mutex = new std::mutex();
......@@ -569,7 +569,7 @@ void testGemmConsistencyBatchedMatmul(size_t m, size_t n, size_t k) {
c_tensors.clear();
expecteds.clear();
delete cublas_wrapper_mutex;
check_cuda_error(cublasLtDestroy(cublaslt_handle));
// check_cuda_error(cublasLtDestroy(cublaslt_handle));
check_cuda_error(cublasDestroy(cublas_handle));
check_cuda_error(cudaStreamDestroy(stream));
}
......@@ -594,7 +594,7 @@ void testGemmConsistencyStridedBatchedMatmul(size_t batch_size, size_t m, size_t
cublasHandle_t cublas_handle;
cublasLtHandle_t cublaslt_handle;
check_cuda_error(cublasCreate(&cublas_handle));
check_cuda_error(cublasLtCreate(&cublaslt_handle));
// check_cuda_error(cublasLtCreate(&cublaslt_handle));
check_cuda_error(cublasSetStream(cublas_handle, stream));
cublasAlgoMap cublas_algo_map(GEMM_CONFIG);
std::mutex* cublas_wrapper_mutex = new std::mutex();
......@@ -683,7 +683,7 @@ void testGemmConsistencyStridedBatchedMatmul(size_t batch_size, size_t m, size_t
}
delete cublas_wrapper_mutex;
check_cuda_error(cublasLtDestroy(cublaslt_handle));
// check_cuda_error(cublasLtDestroy(cublaslt_handle));
check_cuda_error(cublasDestroy(cublas_handle));
check_cuda_error(cudaStreamDestroy(stream));
}
......@@ -779,7 +779,7 @@ void testSpGemmConsistencyMatmul(size_t m, size_t n, size_t k) {
cublasHandle_t cublas_handle;
cublasLtHandle_t cublaslt_handle;
check_cuda_error(cublasCreate(&cublas_handle));
check_cuda_error(cublasLtCreate(&cublaslt_handle));
// check_cuda_error(cublasLtCreate(&cublaslt_handle));
check_cuda_error(cublasSetStream(cublas_handle, stream));
cublasAlgoMap cublas_algo_map(GEMM_CONFIG);
std::mutex* cublas_wrapper_mutex = new std::mutex();
......@@ -844,7 +844,7 @@ void testSpGemmConsistencyMatmul(size_t m, size_t n, size_t k) {
}
delete cublas_wrapper_mutex;
check_cuda_error(cublasLtDestroy(cublaslt_handle));
// check_cuda_error(cublasLtDestroy(cublaslt_handle));
check_cuda_error(cublasDestroy(cublas_handle));
check_cuda_error(cudaStreamDestroy(stream));
}
......
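Every test above repeats the same create/set-stream/destroy sequence for the stream and the cuBLAS handle, with the cublasLt calls now commented out in each place. A small RAII guard, sketched below under the assumption that only the cublas handle and the stream are needed, would remove that duplication; it is a suggestion, not code from this commit.

```cpp
// Sketch: RAII wrapper for the stream + cuBLAS handle pair that each test
// currently sets up and tears down by hand. Assumes cublasLt stays
// disabled, matching the commented-out calls in this commit.
#include <cublas_v2.h>
#include <cuda_runtime.h>

struct CublasTestContext {
    cudaStream_t   stream{};
    cublasHandle_t cublas_handle{};

    CublasTestContext()
    {
        check_cuda_error(cudaStreamCreate(&stream));
        check_cuda_error(cublasCreate(&cublas_handle));
        check_cuda_error(cublasSetStream(cublas_handle, stream));
    }
    ~CublasTestContext()
    {
        // Destructors should not throw; ignore teardown status here.
        cublasDestroy(cublas_handle);
        cudaStreamDestroy(stream);
    }
    CublasTestContext(const CublasTestContext&) = delete;
    CublasTestContext& operator=(const CublasTestContext&) = delete;
};
```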