"tests/git@developer.sourcefind.cn:tsoc/superbenchmark.git" did not exist on "4fa60be7cd4c0ad6161a0b01b7e239dcf76270ec"
Unverified commit a55002de, authored by Yuting Jiang and committed by GitHub
Browse files

Bug: Fix bug for incorrect datatype judgement in cublas-function source code (#464)

**Description**
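Fix incorrect datatype checks in the cublas-function source code. `std::string::compare` returns 0 when the strings are equal, so conditions such as `datatype_.compare("half")` were true for every datatype except the intended one; they now test `compare(...) == 0`. `from_json` also reads `datatype` and `use_tensor_core` in separate try blocks instead of falling back to `"float"` / `false` together when either lookup fails.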
parent c387d9c0
...@@ -170,12 +170,12 @@ class GemmExFunction : public CublasFunction { ...@@ -170,12 +170,12 @@ class GemmExFunction : public CublasFunction {
* @brief Prepare memory and data of the input and output for kernel running * @brief Prepare memory and data of the input and output for kernel running
*/ */
virtual void prepare_tensor() { virtual void prepare_tensor() {
if (this->datatype_.compare("half")) { if (this->datatype_.compare("half") == 0) {
CublasFunction::prepare_tensor_template<half>( CublasFunction::prepare_tensor_template<half>(
reinterpret_cast<half **>(&Parameter_0_0), reinterpret_cast<half **>(&Parameter_1_0), reinterpret_cast<half **>(&Parameter_0_0), reinterpret_cast<half **>(&Parameter_1_0),
reinterpret_cast<half **>(&Result_3_0), reinterpret_cast<half **>(&Parameter_0_0_host), reinterpret_cast<half **>(&Result_3_0), reinterpret_cast<half **>(&Parameter_0_0_host),
reinterpret_cast<half **>(&Parameter_1_0_host)); reinterpret_cast<half **>(&Parameter_1_0_host));
} else if (this->datatype_.compare("float")) { } else if (this->datatype_.compare("float") == 0) {
CublasFunction::prepare_tensor_template<float>( CublasFunction::prepare_tensor_template<float>(
reinterpret_cast<float **>(&Parameter_0_0), reinterpret_cast<float **>(&Parameter_1_0), reinterpret_cast<float **>(&Parameter_0_0), reinterpret_cast<float **>(&Parameter_1_0),
reinterpret_cast<float **>(&Result_3_0), reinterpret_cast<float **>(&Parameter_0_0_host), reinterpret_cast<float **>(&Result_3_0), reinterpret_cast<float **>(&Parameter_0_0_host),
...@@ -186,11 +186,11 @@ class GemmExFunction : public CublasFunction { ...@@ -186,11 +186,11 @@ class GemmExFunction : public CublasFunction {
* @brief Function calculation on CPU side * @brief Function calculation on CPU side
*/ */
virtual void matrix_calculation_on_cpu() { virtual void matrix_calculation_on_cpu() {
if (this->datatype_.compare("half")) { if (this->datatype_.compare("half") == 0) {
matrix_calculation_on_cpu_with_data( matrix_calculation_on_cpu_with_data(
reinterpret_cast<half *>(Parameter_0_0_host), reinterpret_cast<half *>(Parameter_1_0_host), reinterpret_cast<half *>(Parameter_0_0_host), reinterpret_cast<half *>(Parameter_1_0_host),
reinterpret_cast<half *>(Result_3_0), reinterpret_cast<float **>(&Result_cpu)); reinterpret_cast<half *>(Result_3_0), reinterpret_cast<float **>(&Result_cpu));
} else if (this->datatype_.compare("float")) { } else if (this->datatype_.compare("float") == 0) {
matrix_calculation_on_cpu_with_data( matrix_calculation_on_cpu_with_data(
reinterpret_cast<float *>(Parameter_0_0_host), reinterpret_cast<float *>(Parameter_1_0_host), reinterpret_cast<float *>(Parameter_0_0_host), reinterpret_cast<float *>(Parameter_1_0_host),
reinterpret_cast<float *>(Result_3_0), reinterpret_cast<float **>(&Result_cpu)); reinterpret_cast<float *>(Result_3_0), reinterpret_cast<float **>(&Result_cpu));
...@@ -201,11 +201,11 @@ class GemmExFunction : public CublasFunction { ...@@ -201,11 +201,11 @@ class GemmExFunction : public CublasFunction {
*/ */
virtual int correctness_check() { virtual int correctness_check() {
int result = 0; int result = 0;
if (this->datatype_.compare("half")) { if (this->datatype_.compare("half") == 0) {
double eps = this->eps == 0.0 ? 1.e-3 : this->eps; double eps = this->eps == 0.0 ? 1.e-3 : this->eps;
result = check_result(this->batch_count_, reinterpret_cast<half *>(Result_3_0), result = check_result(this->batch_count_, reinterpret_cast<half *>(Result_3_0),
reinterpret_cast<float *>(Result_cpu), eps); reinterpret_cast<float *>(Result_cpu), eps);
} else if (this->datatype_.compare("float")) { } else if (this->datatype_.compare("float") == 0) {
double eps = this->eps == 0.0 ? 1.e-6 : this->eps; double eps = this->eps == 0.0 ? 1.e-6 : this->eps;
result = check_result(this->batch_count_, reinterpret_cast<float *>(Result_3_0), result = check_result(this->batch_count_, reinterpret_cast<float *>(Result_3_0),
reinterpret_cast<float *>(Result_cpu), eps); reinterpret_cast<float *>(Result_cpu), eps);
...@@ -266,12 +266,12 @@ class GemmStridedBatchedExFunction : public CublasFunction { ...@@ -266,12 +266,12 @@ class GemmStridedBatchedExFunction : public CublasFunction {
* @brief Prepare memory and data of the input and output for kernel running * @brief Prepare memory and data of the input and output for kernel running
*/ */
virtual void prepare_tensor() { virtual void prepare_tensor() {
if (this->datatype_.compare("half")) { if (this->datatype_.compare("half") == 0) {
prepare_tensor_template<half>( prepare_tensor_template<half>(
reinterpret_cast<half **>(&Parameter_0_0), reinterpret_cast<half **>(&Parameter_1_0), reinterpret_cast<half **>(&Parameter_0_0), reinterpret_cast<half **>(&Parameter_1_0),
reinterpret_cast<half **>(&Result_3_0), reinterpret_cast<half **>(&Parameter_0_0_host), reinterpret_cast<half **>(&Result_3_0), reinterpret_cast<half **>(&Parameter_0_0_host),
reinterpret_cast<half **>(&Parameter_1_0_host)); reinterpret_cast<half **>(&Parameter_1_0_host));
} else if (this->datatype_.compare("float")) { } else if (this->datatype_.compare("float") == 0) {
prepare_tensor_template<float>( prepare_tensor_template<float>(
reinterpret_cast<float **>(&Parameter_0_0), reinterpret_cast<float **>(&Parameter_1_0), reinterpret_cast<float **>(&Parameter_0_0), reinterpret_cast<float **>(&Parameter_1_0),
reinterpret_cast<float **>(&Result_3_0), reinterpret_cast<float **>(&Parameter_0_0_host), reinterpret_cast<float **>(&Result_3_0), reinterpret_cast<float **>(&Parameter_0_0_host),
...@@ -282,11 +282,11 @@ class GemmStridedBatchedExFunction : public CublasFunction { ...@@ -282,11 +282,11 @@ class GemmStridedBatchedExFunction : public CublasFunction {
* @brief Function calculation on CPU side * @brief Function calculation on CPU side
*/ */
virtual void matrix_calculation_on_cpu() { virtual void matrix_calculation_on_cpu() {
if (this->datatype_.compare("half")) { if (this->datatype_.compare("half") == 0) {
matrix_calculation_on_cpu_with_data( matrix_calculation_on_cpu_with_data(
reinterpret_cast<half *>(Parameter_0_0_host), reinterpret_cast<half *>(Parameter_1_0_host), reinterpret_cast<half *>(Parameter_0_0_host), reinterpret_cast<half *>(Parameter_1_0_host),
reinterpret_cast<half *>(Result_3_0), reinterpret_cast<float **>(&Result_cpu)); reinterpret_cast<half *>(Result_3_0), reinterpret_cast<float **>(&Result_cpu), 1.0f, 1.0f);
} else if (this->datatype_.compare("float"), 1.0f, 1.0f) { } else if (this->datatype_.compare("float") == 0) {
matrix_calculation_on_cpu_with_data( matrix_calculation_on_cpu_with_data(
reinterpret_cast<float *>(Parameter_0_0_host), reinterpret_cast<float *>(Parameter_1_0_host), reinterpret_cast<float *>(Parameter_0_0_host), reinterpret_cast<float *>(Parameter_1_0_host),
reinterpret_cast<float *>(Result_3_0), reinterpret_cast<float **>(&Result_cpu), 1.0f, 1.0f); reinterpret_cast<float *>(Result_3_0), reinterpret_cast<float **>(&Result_cpu), 1.0f, 1.0f);
...@@ -297,11 +297,11 @@ class GemmStridedBatchedExFunction : public CublasFunction { ...@@ -297,11 +297,11 @@ class GemmStridedBatchedExFunction : public CublasFunction {
*/ */
virtual int correctness_check() { virtual int correctness_check() {
int result = 0; int result = 0;
if (this->datatype_.compare("half")) { if (this->datatype_.compare("half") == 0) {
double eps = this->eps == 0.0 ? 1.e-3 : this->eps; double eps = this->eps == 0.0 ? 1.e-3 : this->eps;
result = check_result(this->batch_count_, reinterpret_cast<half *>(Result_3_0), result = check_result(this->batch_count_, reinterpret_cast<half *>(Result_3_0),
reinterpret_cast<float *>(Result_cpu), eps); reinterpret_cast<float *>(Result_cpu), eps);
} else if (this->datatype_.compare("float")) { } else if (this->datatype_.compare("float") == 0) {
double eps = this->eps == 0.0 ? 1.e-6 : this->eps; double eps = this->eps == 0.0 ? 1.e-6 : this->eps;
result = check_result(this->batch_count_, reinterpret_cast<float *>(Result_3_0), result = check_result(this->batch_count_, reinterpret_cast<float *>(Result_3_0),
reinterpret_cast<float *>(Result_cpu), eps); reinterpret_cast<float *>(Result_cpu), eps);
......
...@@ -175,13 +175,25 @@ void from_json(const json &j, CublasFunction &fn) { ...@@ -175,13 +175,25 @@ void from_json(const json &j, CublasFunction &fn) {
fn.set_batch_count(1); fn.set_batch_count(1);
} }
try { try {
auto datatype = j.at("datatype").get<std::string>(); if (str.find("datatype") == std::string::npos) {
fn.set_datatype(datatype); fn.set_datatype("unknown");
auto use_tensor_core = j.at("use_tensor_core").get<bool>(); } else {
fn.set_use_tensor_core(use_tensor_core); auto datatype = j.at("datatype").get<std::string>();
fn.set_datatype(datatype);
}
} catch (std::exception &e) {
throw std::runtime_error("datatype is required for cublasGemmEx and cublasGemmStridedBatchedEx");
}
try {
if (str.find("use_tensor_core") == std::string::npos) {
fn.set_use_tensor_core(false);
} else {
auto use_tensor_core = j.at("use_tensor_core").get<bool>();
fn.set_use_tensor_core(use_tensor_core);
}
} catch (std::exception &e) { } catch (std::exception &e) {
fn.set_datatype("float"); throw std::runtime_error("use_tensor_core is required for cublasGemmEx and cublasGemmStridedBatchedEx");
fn.set_use_tensor_core(false);
} }
} }
......
...@@ -116,10 +116,10 @@ void gemmEx(cublasHandle_t handle, int transa, int transb, int m, int n, int k, ...@@ -116,10 +116,10 @@ void gemmEx(cublasHandle_t handle, int transa, int transb, int m, int n, int k,
cudaDataType_t matrix_type; cudaDataType_t matrix_type;
cublasGemmAlgo_t algo; cublasGemmAlgo_t algo;
algo = (use_tensor_core ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT); algo = (use_tensor_core ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
if (type.compare("float")) { if (type.compare("float") == 0) {
matrix_type = CUDA_R_32F; matrix_type = CUDA_R_32F;
} else { } else {
if (type.compare("half")) { if (type.compare("half") == 0) {
matrix_type = CUDA_R_16F; matrix_type = CUDA_R_16F;
} else { } else {
throw "invalid datatype"; throw "invalid datatype";
...@@ -153,10 +153,10 @@ void gemmStridedBatchedEx(cublasHandle_t handle, int transa, int transb, int m, ...@@ -153,10 +153,10 @@ void gemmStridedBatchedEx(cublasHandle_t handle, int transa, int transb, int m,
cudaDataType_t matrix_type; cudaDataType_t matrix_type;
cublasGemmAlgo_t algo; cublasGemmAlgo_t algo;
algo = (use_tensor_core ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT); algo = (use_tensor_core ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
if (type.compare("float")) { if (type.compare("float") == 0) {
matrix_type = CUDA_R_32F; matrix_type = CUDA_R_32F;
} else { } else {
if (type.compare("half")) { if (type.compare("half") == 0) {
matrix_type = CUDA_R_16F; matrix_type = CUDA_R_16F;
} else { } else {
throw "invalid datatype"; throw "invalid datatype";
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment