Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
LLama_fastertransformer
Commits
a929d1c6
Commit
a929d1c6
authored
Sep 05, 2023
by
zhuwenwen
Browse files
support unittests
parent
d42788f0
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
17 additions
and
12 deletions
+17
-12
src/fastertransformer/utils/gemm.cc
src/fastertransformer/utils/gemm.cc
+4
-3
tests/unittests/CMakeLists.txt
tests/unittests/CMakeLists.txt
+4
-0
tests/unittests/test_context_decoder_layer.cu
tests/unittests/test_context_decoder_layer.cu
+1
-1
tests/unittests/test_gemm.cu
tests/unittests/test_gemm.cu
+8
-8
No files found.
src/fastertransformer/utils/gemm.cc
View file @
a929d1c6
...
...
@@ -26,7 +26,7 @@ Gemm::Gemm(IAllocator* allocator, cudaStream_t stream, std::string config_file)
stream_
=
stream
;
mutex_
=
new
std
::
mutex
();
// mutex per process
check_cuda_error
(
cublasCreate
(
&
cublas_handle_
));
check_cuda_error
(
cublasLtCreate
(
&
cublaslt_handle_
));
//
check_cuda_error(cublasLtCreate(&cublaslt_handle_));
check_cuda_error
(
cublasSetStream
(
cublas_handle_
,
stream
));
if
(
allocator_
!=
nullptr
)
{
...
...
@@ -41,7 +41,7 @@ Gemm::~Gemm()
allocator_
->
free
((
void
**
)(
&
workspace_
));
allocator_
=
nullptr
;
}
cublasLtDestroy
(
cublaslt_handle_
);
//
cublasLtDestroy(cublaslt_handle_);
cublasDestroy
(
cublas_handle_
);
delete
cublas_algo_map_
;
delete
mutex_
;
...
...
@@ -248,7 +248,8 @@ void Gemm::gemm(const GemmOp transa,
mutex_
->
lock
();
// Use cublas as default in FP32 and cublasLt as default in FP16
bool
is_fp16_compute_type
=
compute_type_
==
TYPE_FP16
;
bool
using_cublasLt
=
Atype
==
TYPE_FP16
;
// bool using_cublasLt = Atype == TYPE_FP16;
bool
using_cublasLt
=
(
Atype
==
TYPE_FP16
)
?
false
:
false
;
int
batch_count
=
1
;
half
h_alpha
=
(
half
)
alpha
;
...
...
tests/unittests/CMakeLists.txt
View file @
a929d1c6
...
...
@@ -19,6 +19,10 @@ FetchContent_Declare(
googletest
GIT_REPOSITORY https://github.com/google/googletest.git
GIT_TAG release-1.12.1
# URL /path/to/local/googletest-release-1.12.1.zip
# URL_HASH SHA256=24564e3b712d3eb30ac9a85d92f7d720f60cc0173730ac166f27dda7fed76cb2
# export C_INCLUDE_PATH=/path/to/local/googletest-release-1.12.1/googletest/include${C_INCLUDE_PATH:+:${C_INCLUDE_PATH}}
# export CPLUS_INCLUDE_PATH=/path/to/local/googletest-release-1.12.1/include${CPLUS_INCLUDE_PATH:+:${CPLUS_INCLUDE_PATH}}
)
add_definitions
(
-DTORCH_CUDA=1
)
...
...
tests/unittests/test_context_decoder_layer.cu
View file @
a929d1c6
...
...
@@ -74,7 +74,7 @@ bool test_context_sharing(const std::string& weight_dir, const std::string& data
cublasLtHandle_t
cublaslt_handle
;
check_cuda_error
(
cudaStreamCreate
(
&
stream
));
check_cuda_error
(
cublasCreate
(
&
cublas_handle
));
check_cuda_error
(
cublasLtCreate
(
&
cublaslt_handle
));
//
check_cuda_error(cublasLtCreate(&cublaslt_handle));
check_cuda_error
(
cublasSetStream
(
cublas_handle
,
stream
));
cublasAlgoMap
cublas_algo_map
(
GEMM_CONFIG
);
...
...
tests/unittests/test_gemm.cu
View file @
a929d1c6
...
...
@@ -378,7 +378,7 @@ void testGemmConsistencyMatmul(size_t m, size_t n, size_t k) {
cublasHandle_t
cublas_handle
;
cublasLtHandle_t
cublaslt_handle
;
check_cuda_error
(
cublasCreate
(
&
cublas_handle
));
check_cuda_error
(
cublasLtCreate
(
&
cublaslt_handle
));
//
check_cuda_error(cublasLtCreate(&cublaslt_handle));
check_cuda_error
(
cublasSetStream
(
cublas_handle
,
stream
));
cublasAlgoMap
cublas_algo_map
(
GEMM_CONFIG
);
std
::
mutex
*
cublas_wrapper_mutex
=
new
std
::
mutex
();
...
...
@@ -436,7 +436,7 @@ void testGemmConsistencyMatmul(size_t m, size_t n, size_t k) {
}
delete
cublas_wrapper_mutex
;
check_cuda_error
(
cublasLtDestroy
(
cublaslt_handle
));
//
check_cuda_error(cublasLtDestroy(cublaslt_handle));
check_cuda_error
(
cublasDestroy
(
cublas_handle
));
check_cuda_error
(
cudaStreamDestroy
(
stream
));
}
...
...
@@ -494,7 +494,7 @@ void testGemmConsistencyBatchedMatmul(size_t m, size_t n, size_t k) {
cublasHandle_t
cublas_handle
;
cublasLtHandle_t
cublaslt_handle
;
check_cuda_error
(
cublasCreate
(
&
cublas_handle
));
check_cuda_error
(
cublasLtCreate
(
&
cublaslt_handle
));
//
check_cuda_error(cublasLtCreate(&cublaslt_handle));
check_cuda_error
(
cublasSetStream
(
cublas_handle
,
stream
));
cublasAlgoMap
cublas_algo_map
(
GEMM_CONFIG
);
std
::
mutex
*
cublas_wrapper_mutex
=
new
std
::
mutex
();
...
...
@@ -569,7 +569,7 @@ void testGemmConsistencyBatchedMatmul(size_t m, size_t n, size_t k) {
c_tensors
.
clear
();
expecteds
.
clear
();
delete
cublas_wrapper_mutex
;
check_cuda_error
(
cublasLtDestroy
(
cublaslt_handle
));
//
check_cuda_error(cublasLtDestroy(cublaslt_handle));
check_cuda_error
(
cublasDestroy
(
cublas_handle
));
check_cuda_error
(
cudaStreamDestroy
(
stream
));
}
...
...
@@ -594,7 +594,7 @@ void testGemmConsistencyStridedBatchedMatmul(size_t batch_size, size_t m, size_t
cublasHandle_t
cublas_handle
;
cublasLtHandle_t
cublaslt_handle
;
check_cuda_error
(
cublasCreate
(
&
cublas_handle
));
check_cuda_error
(
cublasLtCreate
(
&
cublaslt_handle
));
//
check_cuda_error(cublasLtCreate(&cublaslt_handle));
check_cuda_error
(
cublasSetStream
(
cublas_handle
,
stream
));
cublasAlgoMap
cublas_algo_map
(
GEMM_CONFIG
);
std
::
mutex
*
cublas_wrapper_mutex
=
new
std
::
mutex
();
...
...
@@ -683,7 +683,7 @@ void testGemmConsistencyStridedBatchedMatmul(size_t batch_size, size_t m, size_t
}
delete
cublas_wrapper_mutex
;
check_cuda_error
(
cublasLtDestroy
(
cublaslt_handle
));
//
check_cuda_error(cublasLtDestroy(cublaslt_handle));
check_cuda_error
(
cublasDestroy
(
cublas_handle
));
check_cuda_error
(
cudaStreamDestroy
(
stream
));
}
...
...
@@ -779,7 +779,7 @@ void testSpGemmConsistencyMatmul(size_t m, size_t n, size_t k) {
cublasHandle_t
cublas_handle
;
cublasLtHandle_t
cublaslt_handle
;
check_cuda_error
(
cublasCreate
(
&
cublas_handle
));
check_cuda_error
(
cublasLtCreate
(
&
cublaslt_handle
));
//
check_cuda_error(cublasLtCreate(&cublaslt_handle));
check_cuda_error
(
cublasSetStream
(
cublas_handle
,
stream
));
cublasAlgoMap
cublas_algo_map
(
GEMM_CONFIG
);
std
::
mutex
*
cublas_wrapper_mutex
=
new
std
::
mutex
();
...
...
@@ -844,7 +844,7 @@ void testSpGemmConsistencyMatmul(size_t m, size_t n, size_t k) {
}
delete
cublas_wrapper_mutex
;
check_cuda_error
(
cublasLtDestroy
(
cublaslt_handle
));
//
check_cuda_error(cublasLtDestroy(cublaslt_handle));
check_cuda_error
(
cublasDestroy
(
cublas_handle
));
check_cuda_error
(
cudaStreamDestroy
(
stream
));
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment