Commit e38ee081 authored by xiabo
Browse files

Adapt to rocm

parent 56942c43
...@@ -218,8 +218,8 @@ void generate_xlnet_gemm_config(int batch_size, ...@@ -218,8 +218,8 @@ void generate_xlnet_gemm_config(int batch_size,
cublasHandle_t cublas_handle; cublasHandle_t cublas_handle;
check_cuda_error(cublasCreate(&cublas_handle)); check_cuda_error(cublasCreate(&cublas_handle));
cublasLtHandle_t ltHandle; // cublasLtHandle_t ltHandle;
check_cuda_error(cublasLtCreate(&ltHandle)); // check_cuda_error(cublasLtCreate(&ltHandle));
cudaDataType_t AType; cudaDataType_t AType;
cudaDataType_t BType; cudaDataType_t BType;
...@@ -244,8 +244,10 @@ void generate_xlnet_gemm_config(int batch_size, ...@@ -244,8 +244,10 @@ void generate_xlnet_gemm_config(int batch_size,
BType = CUDA_R_16F; BType = CUDA_R_16F;
CType = CUDA_R_16F; CType = CUDA_R_16F;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
else if (std::is_same<T, __nv_bfloat16>::value) { else if (std::is_same<T, __nv_bfloat16>::value) {
...@@ -254,8 +256,10 @@ void generate_xlnet_gemm_config(int batch_size, ...@@ -254,8 +256,10 @@ void generate_xlnet_gemm_config(int batch_size,
BType = CUDA_R_16BF; BType = CUDA_R_16BF;
CType = CUDA_R_16BF; CType = CUDA_R_16BF;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#endif #endif
...@@ -358,30 +362,31 @@ void generate_xlnet_gemm_config(int batch_size, ...@@ -358,30 +362,31 @@ void generate_xlnet_gemm_config(int batch_size,
const int ALGO_COMBINATIONS = 5000; const int ALGO_COMBINATIONS = 5000;
customMatmulPerf_t perfResults[ALGO_COMBINATIONS]; customMatmulPerf_t perfResults[ALGO_COMBINATIONS];
LtHgemmCustomFind<T, scaleT>(ltHandle, // LtHgemmCustomFind<T, scaleT>(ltHandle,
batch_size, // batch_size,
seq_len, // seq_len,
head_num, // head_num,
size_per_head, // size_per_head,
n, // n,
m, // m,
k, // k,
&alpha, // &alpha,
d_B, // d_B,
d_A, // d_A,
&beta, // &beta,
d_C, // d_C,
cublas_workspace, // cublas_workspace,
workSpaceSize, // workSpaceSize,
fd, // fd,
perfResults, // perfResults,
ALGO_COMBINATIONS); // ALGO_COMBINATIONS);
if (perfResults[0].time < exec_time) { // if (perfResults[0].time < exec_time) {
printPerfStructure( // printPerfStructure(
batch_size, seq_len, head_num, size_per_head, n, m, k, perfResults[0], fd, data_type, 0); // batch_size, seq_len, head_num, size_per_head, n, m, k, perfResults[0], fd, data_type, 0);
exec_time = perfResults[0].time; // exec_time = perfResults[0].time;
} // }
else { // else {
{
fprintf(fd, fprintf(fd,
"%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 " "%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 "
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3) #if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment