Commit ab8c95cb authored by xiabo
Browse files

Adapt to ROCm: supplementary FT modifications (FT的修改补充)

parent 6939e47c
...@@ -185,7 +185,7 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa, ...@@ -185,7 +185,7 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa,
cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(Atype_)); cublasLtMatmulAlgo_info info = cublas_algo_map_->getAlgo(batch_count, m, n, k, getCublasDataType(Atype_));
if (findAlgo) { if (findAlgo) {
if (info.stages != -1) { if (info.stages != -1) {
using_cublasLt = true; using_cublasLt = false;
} }
else { else {
using_cublasLt = false; using_cublasLt = false;
...@@ -342,7 +342,7 @@ void cublasMMWrapper::setFP16GemmConfig() ...@@ -342,7 +342,7 @@ void cublasMMWrapper::setFP16GemmConfig()
Atype_ = CUDA_R_16F; Atype_ = CUDA_R_16F;
Btype_ = CUDA_R_16F; Btype_ = CUDA_R_16F;
Ctype_ = CUDA_R_16F; Ctype_ = CUDA_R_16F;
computeType_ = CUDA_R_32F; computeType_ = CUDA_R_16F;
} }
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
......
...@@ -148,14 +148,15 @@ void generate_decoding_gemm_config(int batch_size, ...@@ -148,14 +148,15 @@ void generate_decoding_gemm_config(int batch_size,
CType = CUDA_R_32F; CType = CUDA_R_32F;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT; startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_ALGO23; // endAlgo = (int)CUBLAS_GEMM_ALGO23;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
else if (std::is_same<T, half>::value) { else if (std::is_same<T, half>::value) {
data_type = HALF_DATATYPE; data_type = HALF_DATATYPE;
AType = CUDA_R_16F; AType = CUDA_R_16F;
BType = CUDA_R_16F; BType = CUDA_R_16F;
CType = CUDA_R_16F; CType = CUDA_R_16F;
computeType = CUDA_R_32F; computeType = CUDA_R_16F;
// startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
// endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT; startAlgo = (int)CUBLAS_GEMM_DEFAULT;
...@@ -174,7 +175,8 @@ void generate_decoding_gemm_config(int batch_size, ...@@ -174,7 +175,8 @@ void generate_decoding_gemm_config(int batch_size,
endAlgo = (int)CUBLAS_GEMM_DEFAULT; endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#endif #endif
using scaleT = typename ScaleTypeConverter<T>::Type; // using scaleT = typename ScaleTypeConverter<T>::Type;
using scaleT = typename ScaleTypeConverter<T, true>::Type;
scaleT alpha = (scaleT)1.0f; scaleT alpha = (scaleT)1.0f;
scaleT beta = (scaleT)0.0f; scaleT beta = (scaleT)0.0f;
......
...@@ -145,14 +145,15 @@ void generate_encoder_gemm_config( ...@@ -145,14 +145,15 @@ void generate_encoder_gemm_config(
CType = CUDA_R_32F; CType = CUDA_R_32F;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT; startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_ALGO23; // endAlgo = (int)CUBLAS_GEMM_ALGO23;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
else if (std::is_same<T, half>::value) { else if (std::is_same<T, half>::value) {
data_type = HALF_DATATYPE; data_type = HALF_DATATYPE;
AType = CUDA_R_16F; AType = CUDA_R_16F;
BType = CUDA_R_16F; BType = CUDA_R_16F;
CType = CUDA_R_16F; CType = CUDA_R_16F;
computeType = CUDA_R_32F; computeType = CUDA_R_16F;
// startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
// endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT; startAlgo = (int)CUBLAS_GEMM_DEFAULT;
...@@ -171,7 +172,8 @@ void generate_encoder_gemm_config( ...@@ -171,7 +172,8 @@ void generate_encoder_gemm_config(
endAlgo = (int)CUBLAS_GEMM_DEFAULT; endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#endif #endif
using scaleT = typename ScaleTypeConverter<T, false>::Type; // using scaleT = typename ScaleTypeConverter<T, false>::Type;
using scaleT = typename ScaleTypeConverter<T, true>::Type;
scaleT alpha = (scaleT)1.0f; scaleT alpha = (scaleT)1.0f;
scaleT beta = (scaleT)0.0f; scaleT beta = (scaleT)0.0f;
......
...@@ -244,7 +244,8 @@ void generate_gpt_gemm_config(int batch_size, ...@@ -244,7 +244,8 @@ void generate_gpt_gemm_config(int batch_size,
DType = CUDA_R_32F; DType = CUDA_R_32F;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT; startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_ALGO23; // endAlgo = (int)CUBLAS_GEMM_ALGO23;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
else if (std::is_same<T, half>::value) { else if (std::is_same<T, half>::value) {
data_type = HALF_DATATYPE; data_type = HALF_DATATYPE;
...@@ -252,7 +253,7 @@ void generate_gpt_gemm_config(int batch_size, ...@@ -252,7 +253,7 @@ void generate_gpt_gemm_config(int batch_size,
BType = CUDA_R_16F; BType = CUDA_R_16F;
CType = CUDA_R_16F; CType = CUDA_R_16F;
DType = CUDA_R_16F; DType = CUDA_R_16F;
computeType = CUDA_R_32F; computeType = CUDA_R_16F;
// startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
// endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT; startAlgo = (int)CUBLAS_GEMM_DEFAULT;
...@@ -303,8 +304,18 @@ void generate_gpt_gemm_config(int batch_size, ...@@ -303,8 +304,18 @@ void generate_gpt_gemm_config(int batch_size,
endAlgo = (int)CUBLAS_GEMM_DEFAULT; endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#endif #endif
float alpha = (float)1.0f; // float alpha = (float)1.0f;
float beta = (float)0.0f; // float beta = (float)0.0f;
float f_alpha = (float)1.0f;
float f_beta = (float)0.0f;
half h_alpha = (half)(f_alpha);
half h_beta = (half)(f_beta);
int is_fp16_computeType = computeType == CUDA_R_16F ? 1 : 0;
const void* alpha = is_fp16_computeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void*>(&f_alpha);
const void* beta = is_fp16_computeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void*>(&f_beta);
printf("***Encoder Gemm Testing Begin***\n"); printf("***Encoder Gemm Testing Begin***\n");
printf("***Cublas Gemm Testing Begin***\n"); printf("***Cublas Gemm Testing Begin***\n");
...@@ -348,7 +359,7 @@ void generate_gpt_gemm_config(int batch_size, ...@@ -348,7 +359,7 @@ void generate_gpt_gemm_config(int batch_size,
max_input_len, max_input_len,
max_input_len, max_input_len,
size_per_head, size_per_head,
&alpha, &f_alpha,
d_B, d_B,
BType, BType,
size_per_head, size_per_head,
...@@ -357,13 +368,13 @@ void generate_gpt_gemm_config(int batch_size, ...@@ -357,13 +368,13 @@ void generate_gpt_gemm_config(int batch_size,
AType, AType,
size_per_head, size_per_head,
max_input_len * size_per_head, max_input_len * size_per_head,
&beta, &f_beta,
d_C, d_C,
CUDA_R_32F, // CType, CUDA_R_32F, // CType,
max_input_len, max_input_len,
max_input_len * max_input_len, max_input_len * max_input_len,
batchCount[i], batchCount[i],
computeType, CUDA_R_32F,
static_cast<cublasGemmAlgo_t>(algo)); static_cast<cublasGemmAlgo_t>(algo));
} }
else if (i == 2) { else if (i == 2) {
......
...@@ -151,14 +151,15 @@ void generate_swin_gemm_config( ...@@ -151,14 +151,15 @@ void generate_swin_gemm_config(
CType = CUDA_R_32F; CType = CUDA_R_32F;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT; startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_ALGO23; // endAlgo = (int)CUBLAS_GEMM_ALGO23;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
else if (std::is_same<T, half>::value) { else if (std::is_same<T, half>::value) {
data_type = HALF_DATATYPE; data_type = HALF_DATATYPE;
AType = CUDA_R_16F; AType = CUDA_R_16F;
BType = CUDA_R_16F; BType = CUDA_R_16F;
CType = CUDA_R_16F; CType = CUDA_R_16F;
computeType = CUDA_R_32F; computeType = CUDA_R_16F;
// startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
// endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT; startAlgo = (int)CUBLAS_GEMM_DEFAULT;
...@@ -177,7 +178,8 @@ void generate_swin_gemm_config( ...@@ -177,7 +178,8 @@ void generate_swin_gemm_config(
endAlgo = (int)CUBLAS_GEMM_DEFAULT; endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
#endif #endif
using scaleT = typename ScaleTypeConverter<T, false>::Type; // using scaleT = typename ScaleTypeConverter<T, false>::Type;
using scaleT = typename ScaleTypeConverter<T, true>::Type;
scaleT alpha = (scaleT)1.0f; scaleT alpha = (scaleT)1.0f;
scaleT beta = (scaleT)0.0f; scaleT beta = (scaleT)0.0f;
......
...@@ -213,14 +213,15 @@ void generate_t5_gemm_config(int batch_size, ...@@ -213,14 +213,15 @@ void generate_t5_gemm_config(int batch_size,
CType = CUDA_R_32F; CType = CUDA_R_32F;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT; startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_ALGO23; // endAlgo = (int)CUBLAS_GEMM_ALGO23;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
else if (std::is_same<T, half>::value) { else if (std::is_same<T, half>::value) {
data_type = HALF_DATATYPE; data_type = HALF_DATATYPE;
AType = CUDA_R_16F; AType = CUDA_R_16F;
BType = CUDA_R_16F; BType = CUDA_R_16F;
CType = CUDA_R_16F; CType = CUDA_R_16F;
computeType = CUDA_R_32F; computeType = CUDA_R_16F;
// startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
// endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT; startAlgo = (int)CUBLAS_GEMM_DEFAULT;
......
...@@ -236,14 +236,15 @@ void generate_xlnet_gemm_config(int batch_size, ...@@ -236,14 +236,15 @@ void generate_xlnet_gemm_config(int batch_size,
CType = CUDA_R_32F; CType = CUDA_R_32F;
computeType = CUDA_R_32F; computeType = CUDA_R_32F;
startAlgo = (int)CUBLAS_GEMM_DEFAULT; startAlgo = (int)CUBLAS_GEMM_DEFAULT;
endAlgo = (int)CUBLAS_GEMM_ALGO23; // endAlgo = (int)CUBLAS_GEMM_ALGO23;
endAlgo = (int)CUBLAS_GEMM_DEFAULT;
} }
else if (std::is_same<T, half>::value) { else if (std::is_same<T, half>::value) {
data_type = HALF_DATATYPE; data_type = HALF_DATATYPE;
AType = CUDA_R_16F; AType = CUDA_R_16F;
BType = CUDA_R_16F; BType = CUDA_R_16F;
CType = CUDA_R_16F; CType = CUDA_R_16F;
computeType = CUDA_R_32F; computeType = CUDA_R_16F;
// startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; // startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
// endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; // endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
startAlgo = (int)CUBLAS_GEMM_DEFAULT; startAlgo = (int)CUBLAS_GEMM_DEFAULT;
...@@ -263,7 +264,8 @@ void generate_xlnet_gemm_config(int batch_size, ...@@ -263,7 +264,8 @@ void generate_xlnet_gemm_config(int batch_size,
} }
#endif #endif
using scaleT = typename ScaleTypeConverter<T, false>::Type; // using scaleT = typename ScaleTypeConverter<T, false>::Type;
using scaleT = typename ScaleTypeConverter<T, true>::Type;
scaleT alpha = (scaleT)1.0f; scaleT alpha = (scaleT)1.0f;
scaleT beta = (scaleT)0.0f; scaleT beta = (scaleT)0.0f;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment