"...dcu-process-montor.git" did not exist on "d40a98f31aaeed17d82317eba2565a0bfbb9b196"
Commit e38ee081 authored by xiabo
Browse files

Adapt to ROCm

parent 56942c43
......@@ -218,8 +218,8 @@ void generate_xlnet_gemm_config(int batch_size,
cublasHandle_t cublas_handle;
check_cuda_error(cublasCreate(&cublas_handle));
cublasLtHandle_t ltHandle;
check_cuda_error(cublasLtCreate(&ltHandle));
// cublasLtHandle_t ltHandle;
// check_cuda_error(cublasLtCreate(&ltHandle));
cudaDataType_t AType;
cudaDataType_t BType;
......@@ -244,8 +244,10 @@ void generate_xlnet_gemm_config(int batch_size,
BType = CUDA_R_16F;
CType = CUDA_R_16F;
computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
// startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
// endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
}
#ifdef ENABLE_BF16
else if (std::is_same<T, __nv_bfloat16>::value) {
......@@ -254,8 +256,10 @@ void generate_xlnet_gemm_config(int batch_size,
BType = CUDA_R_16BF;
CType = CUDA_R_16BF;
computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
// startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
// endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
}
#endif
......@@ -358,30 +362,31 @@ void generate_xlnet_gemm_config(int batch_size,
const int ALGO_COMBINATIONS = 5000;
customMatmulPerf_t perfResults[ALGO_COMBINATIONS];
LtHgemmCustomFind<T, scaleT>(ltHandle,
batch_size,
seq_len,
head_num,
size_per_head,
n,
m,
k,
&alpha,
d_B,
d_A,
&beta,
d_C,
cublas_workspace,
workSpaceSize,
fd,
perfResults,
ALGO_COMBINATIONS);
if (perfResults[0].time < exec_time) {
printPerfStructure(
batch_size, seq_len, head_num, size_per_head, n, m, k, perfResults[0], fd, data_type, 0);
exec_time = perfResults[0].time;
}
else {
// LtHgemmCustomFind<T, scaleT>(ltHandle,
// batch_size,
// seq_len,
// head_num,
// size_per_head,
// n,
// m,
// k,
// &alpha,
// d_B,
// d_A,
// &beta,
// d_C,
// cublas_workspace,
// workSpaceSize,
// fd,
// perfResults,
// ALGO_COMBINATIONS);
// if (perfResults[0].time < exec_time) {
// printPerfStructure(
// batch_size, seq_len, head_num, size_per_head, n, m, k, perfResults[0], fd, data_type, 0);
// exec_time = perfResults[0].time;
// }
// else {
{
fprintf(fd,
"%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 "
#if (CUBLAS_VER_MAJOR == 11 && CUBLAS_VER_MINOR == 11 && CUBLAS_VER_PATCH >= 3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment