Unverified Commit a55002de authored by Yuting Jiang's avatar Yuting Jiang Committed by GitHub
Browse files

Bug: Fix incorrect datatype comparison (misused `std::string::compare` return value) in cublas-function source code (#464)

**Description**
Fix the incorrect datatype comparison in the cublas-function source code: `std::string::compare` returns 0 when the strings match, so its result must be tested with `== 0` rather than used directly as a boolean (the old code treated a *mismatch* as true).
parent c387d9c0
......@@ -170,12 +170,12 @@ class GemmExFunction : public CublasFunction {
* @brief Prepare memory and data of the input and output for kernel running
*/
virtual void prepare_tensor() {
if (this->datatype_.compare("half")) {
if (this->datatype_.compare("half") == 0) {
CublasFunction::prepare_tensor_template<half>(
reinterpret_cast<half **>(&Parameter_0_0), reinterpret_cast<half **>(&Parameter_1_0),
reinterpret_cast<half **>(&Result_3_0), reinterpret_cast<half **>(&Parameter_0_0_host),
reinterpret_cast<half **>(&Parameter_1_0_host));
} else if (this->datatype_.compare("float")) {
} else if (this->datatype_.compare("float") == 0) {
CublasFunction::prepare_tensor_template<float>(
reinterpret_cast<float **>(&Parameter_0_0), reinterpret_cast<float **>(&Parameter_1_0),
reinterpret_cast<float **>(&Result_3_0), reinterpret_cast<float **>(&Parameter_0_0_host),
......@@ -186,11 +186,11 @@ class GemmExFunction : public CublasFunction {
* @brief Function calculation on CPU side
*/
virtual void matrix_calculation_on_cpu() {
if (this->datatype_.compare("half")) {
if (this->datatype_.compare("half") == 0) {
matrix_calculation_on_cpu_with_data(
reinterpret_cast<half *>(Parameter_0_0_host), reinterpret_cast<half *>(Parameter_1_0_host),
reinterpret_cast<half *>(Result_3_0), reinterpret_cast<float **>(&Result_cpu));
} else if (this->datatype_.compare("float")) {
} else if (this->datatype_.compare("float") == 0) {
matrix_calculation_on_cpu_with_data(
reinterpret_cast<float *>(Parameter_0_0_host), reinterpret_cast<float *>(Parameter_1_0_host),
reinterpret_cast<float *>(Result_3_0), reinterpret_cast<float **>(&Result_cpu));
......@@ -201,11 +201,11 @@ class GemmExFunction : public CublasFunction {
*/
virtual int correctness_check() {
int result = 0;
if (this->datatype_.compare("half")) {
if (this->datatype_.compare("half") == 0) {
double eps = this->eps == 0.0 ? 1.e-3 : this->eps;
result = check_result(this->batch_count_, reinterpret_cast<half *>(Result_3_0),
reinterpret_cast<float *>(Result_cpu), eps);
} else if (this->datatype_.compare("float")) {
} else if (this->datatype_.compare("float") == 0) {
double eps = this->eps == 0.0 ? 1.e-6 : this->eps;
result = check_result(this->batch_count_, reinterpret_cast<float *>(Result_3_0),
reinterpret_cast<float *>(Result_cpu), eps);
......@@ -266,12 +266,12 @@ class GemmStridedBatchedExFunction : public CublasFunction {
* @brief Prepare memory and data of the input and output for kernel running
*/
virtual void prepare_tensor() {
if (this->datatype_.compare("half")) {
if (this->datatype_.compare("half") == 0) {
prepare_tensor_template<half>(
reinterpret_cast<half **>(&Parameter_0_0), reinterpret_cast<half **>(&Parameter_1_0),
reinterpret_cast<half **>(&Result_3_0), reinterpret_cast<half **>(&Parameter_0_0_host),
reinterpret_cast<half **>(&Parameter_1_0_host));
} else if (this->datatype_.compare("float")) {
} else if (this->datatype_.compare("float") == 0) {
prepare_tensor_template<float>(
reinterpret_cast<float **>(&Parameter_0_0), reinterpret_cast<float **>(&Parameter_1_0),
reinterpret_cast<float **>(&Result_3_0), reinterpret_cast<float **>(&Parameter_0_0_host),
......@@ -282,11 +282,11 @@ class GemmStridedBatchedExFunction : public CublasFunction {
* @brief Function calculation on CPU side
*/
virtual void matrix_calculation_on_cpu() {
if (this->datatype_.compare("half")) {
if (this->datatype_.compare("half") == 0) {
matrix_calculation_on_cpu_with_data(
reinterpret_cast<half *>(Parameter_0_0_host), reinterpret_cast<half *>(Parameter_1_0_host),
reinterpret_cast<half *>(Result_3_0), reinterpret_cast<float **>(&Result_cpu));
                reinterpret_cast<half *>(Result_3_0), reinterpret_cast<float **>(&Result_cpu), 1.0f, 1.0f);
        } else if (this->datatype_.compare("float")) {
} else if (this->datatype_.compare("float") == 0) {
matrix_calculation_on_cpu_with_data(
reinterpret_cast<float *>(Parameter_0_0_host), reinterpret_cast<float *>(Parameter_1_0_host),
reinterpret_cast<float *>(Result_3_0), reinterpret_cast<float **>(&Result_cpu), 1.0f, 1.0f);
......@@ -297,11 +297,11 @@ class GemmStridedBatchedExFunction : public CublasFunction {
*/
virtual int correctness_check() {
int result = 0;
if (this->datatype_.compare("half")) {
if (this->datatype_.compare("half") == 0) {
double eps = this->eps == 0.0 ? 1.e-3 : this->eps;
result = check_result(this->batch_count_, reinterpret_cast<half *>(Result_3_0),
reinterpret_cast<float *>(Result_cpu), eps);
} else if (this->datatype_.compare("float")) {
} else if (this->datatype_.compare("float") == 0) {
double eps = this->eps == 0.0 ? 1.e-6 : this->eps;
result = check_result(this->batch_count_, reinterpret_cast<float *>(Result_3_0),
reinterpret_cast<float *>(Result_cpu), eps);
......
......@@ -175,13 +175,25 @@ void from_json(const json &j, CublasFunction &fn) {
fn.set_batch_count(1);
}
try {
auto datatype = j.at("datatype").get<std::string>();
fn.set_datatype(datatype);
auto use_tensor_core = j.at("use_tensor_core").get<bool>();
fn.set_use_tensor_core(use_tensor_core);
if (str.find("datatype") == std::string::npos) {
fn.set_datatype("unknown");
} else {
auto datatype = j.at("datatype").get<std::string>();
fn.set_datatype(datatype);
}
} catch (std::exception &e) {
throw std::runtime_error("datatype is required for cublasGemmEx and cublasGemmStridedBatchedEx");
}
try {
if (str.find("use_tensor_core") == std::string::npos) {
fn.set_use_tensor_core(false);
} else {
auto use_tensor_core = j.at("use_tensor_core").get<bool>();
fn.set_use_tensor_core(use_tensor_core);
}
} catch (std::exception &e) {
fn.set_datatype("float");
fn.set_use_tensor_core(false);
throw std::runtime_error("use_tensor_core is required for cublasGemmEx and cublasGemmStridedBatchedEx");
}
}
......
......@@ -116,10 +116,10 @@ void gemmEx(cublasHandle_t handle, int transa, int transb, int m, int n, int k,
cudaDataType_t matrix_type;
cublasGemmAlgo_t algo;
algo = (use_tensor_core ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
if (type.compare("float")) {
if (type.compare("float") == 0) {
matrix_type = CUDA_R_32F;
} else {
if (type.compare("half")) {
if (type.compare("half") == 0) {
matrix_type = CUDA_R_16F;
} else {
throw "invalid datatype";
......@@ -153,10 +153,10 @@ void gemmStridedBatchedEx(cublasHandle_t handle, int transa, int transb, int m,
cudaDataType_t matrix_type;
cublasGemmAlgo_t algo;
algo = (use_tensor_core ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT);
if (type.compare("float")) {
if (type.compare("float") == 0) {
matrix_type = CUDA_R_32F;
} else {
if (type.compare("half")) {
if (type.compare("half") == 0) {
matrix_type = CUDA_R_16F;
} else {
throw "invalid datatype";
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment