Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Lmdeploy
Commits
e38ee081
"...dcu-process-montor.git" did not exist on "d40a98f31aaeed17d82317eba2565a0bfbb9b196"
Commit
e38ee081
authored
Nov 14, 2023
by
xiabo
Browse files
Adapt to rocm
parent
56942c43
Changes
41
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
35 additions
and
30 deletions
+35
-30
src/turbomind/utils/gemm_test/xlnet_gemm_func.cc
src/turbomind/utils/gemm_test/xlnet_gemm_func.cc
+35
-30
No files found.
src/turbomind/utils/gemm_test/xlnet_gemm_func.cc
View file @
e38ee081
...
...
@@ -218,8 +218,8 @@ void generate_xlnet_gemm_config(int batch_size,
cublasHandle_t
cublas_handle
;
check_cuda_error
(
cublasCreate
(
&
cublas_handle
));
cublasLtHandle_t
ltHandle
;
check_cuda_error
(
cublasLtCreate
(
&
ltHandle
));
//
cublasLtHandle_t ltHandle;
//
check_cuda_error(cublasLtCreate(<Handle));
cudaDataType_t
AType
;
cudaDataType_t
BType
;
...
...
@@ -244,8 +244,10 @@ void generate_xlnet_gemm_config(int batch_size,
BType
=
CUDA_R_16F
;
CType
=
CUDA_R_16F
;
computeType
=
CUDA_R_32F
;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT_TENSOR_OP
;
endAlgo
=
(
int
)
CUBLAS_GEMM_ALGO15_TENSOR_OP
;
// startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
// endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
endAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
}
#ifdef ENABLE_BF16
else
if
(
std
::
is_same
<
T
,
__nv_bfloat16
>::
value
)
{
...
...
@@ -254,8 +256,10 @@ void generate_xlnet_gemm_config(int batch_size,
BType
=
CUDA_R_16BF
;
CType
=
CUDA_R_16BF
;
computeType
=
CUDA_R_32F
;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT_TENSOR_OP
;
endAlgo
=
(
int
)
CUBLAS_GEMM_ALGO15_TENSOR_OP
;
// startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
// endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
endAlgo
=
(
int
)
CUBLAS_GEMM_DEFAULT
;
}
#endif
...
...
@@ -358,30 +362,31 @@ void generate_xlnet_gemm_config(int batch_size,
const
int
ALGO_COMBINATIONS
=
5000
;
customMatmulPerf_t
perfResults
[
ALGO_COMBINATIONS
];
LtHgemmCustomFind
<
T
,
scaleT
>
(
ltHandle
,
batch_size
,
seq_len
,
head_num
,
size_per_head
,
n
,
m
,
k
,
&
alpha
,
d_B
,
d_A
,
&
beta
,
d_C
,
cublas_workspace
,
workSpaceSize
,
fd
,
perfResults
,
ALGO_COMBINATIONS
);
if
(
perfResults
[
0
].
time
<
exec_time
)
{
printPerfStructure
(
batch_size
,
seq_len
,
head_num
,
size_per_head
,
n
,
m
,
k
,
perfResults
[
0
],
fd
,
data_type
,
0
);
exec_time
=
perfResults
[
0
].
time
;
}
else
{
// LtHgemmCustomFind<T, scaleT>(ltHandle,
// batch_size,
// seq_len,
// head_num,
// size_per_head,
// n,
// m,
// k,
// &alpha,
// d_B,
// d_A,
// &beta,
// d_C,
// cublas_workspace,
// workSpaceSize,
// fd,
// perfResults,
// ALGO_COMBINATIONS);
// if (perfResults[0].time < exec_time) {
// printPerfStructure(
// batch_size, seq_len, head_num, size_per_head, n, m, k, perfResults[0], fd, data_type, 0);
// exec_time = perfResults[0].time;
// }
// else {
{
fprintf
(
fd
,
"%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 "
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment