Unverified Commit 678b1251 authored by Yuting Jiang's avatar Yuting Jiang Committed by GitHub
Browse files

Benchmarks: Micro benchmarks - add source code of correctness check for cublas functions (#450)

**Description**
Add c source code of correctness check for cublas functions.

**Major Revision**
- add correctness check for all supported cublas functions
- add --correctness option into binary

**Minor Revision**
- fix a bug, and turn fill_data and prepare_tensor into templates so the output matrix is correctly memory-aligned for each supported datatype
parent 9dfefce3
......@@ -202,6 +202,7 @@ def multi_rules(rule, details, categories, store_values):
categories (set): categories of violated rules
store_values (dict): including the number of the metrics that violate the rule, and the values of
the metrics for the rules with 'store' True
Returns:
number: 0 if the rule is passed, otherwise 1
"""
......
......@@ -89,8 +89,9 @@ def _format_summary_of_rule(self, category, summary_df_of_rule, statistics):
Args:
category (str): category in the rule
summary_df_of_rule ([type]): summary df of a rule, the columns are metrics, the index are statistics
summary_df_of_rule (DataFrame): summary df of a rule, the columns are metrics, the index are statistics
statistics (list): statistics in the rule
Returns:
list: list of summary lines like [category, metric, statistic, value]
"""
......
......@@ -14,9 +14,13 @@
* @brief Class of SgemmFunction
*/
class SgemmFunction : public CublasFunction {
float *Parameter_0_0; ///< the pointer of the first input data
float *Parameter_1_0; ///< the pointer of the second input data
float *Result_3_0; ///< the pointer of output data
float *Parameter_0_0; ///< the pointer of the first input data
float *Parameter_1_0; ///< the pointer of the second input data
float *Result_3_0; ///< the pointer of output data
float *Parameter_0_0_host; ///< the pointer of the first input data on host
float *Parameter_1_0_host; ///< the pointer of the second input data on host
float *Result_cpu;
/**
* @brief Execute the kernel/function
*/
......@@ -25,10 +29,26 @@ class SgemmFunction : public CublasFunction {
reinterpret_cast<const float *>(Parameter_0_0), reinterpret_cast<const float *>(Parameter_1_0),
reinterpret_cast<float *>(Result_3_0));
}
    /**
     * @brief Function calculation on CPU side
     *
     * Recomputes the GEMM on the host from the pinned input copies into
     * Result_cpu, which correctness_check() later compares against the GPU
     * output Result_3_0. The two trailing 1.0f arguments are presumably
     * alpha/beta scaling factors — confirm in matrix_calculation_on_cpu_with_data.
     */
    virtual void matrix_calculation_on_cpu() {
        matrix_calculation_on_cpu_with_data(Parameter_0_0_host, Parameter_1_0_host, Result_3_0, &Result_cpu, 1.0f,
                                            1.0f);
    }
/**
* @brief Prepare memory and data of the input and output for kernel running
*/
virtual void prepare_tensor() { CublasFunction::prepare_tensor_float(&Parameter_0_0, &Parameter_1_0, &Result_3_0); }
virtual void prepare_tensor() {
prepare_tensor_template(&Parameter_0_0, &Parameter_1_0, &Result_3_0, &Parameter_0_0_host, &Parameter_1_0_host);
}
/**
* @brief Check the correctness of function calculation result
*/
virtual int correctness_check() {
double eps = this->eps == 0.0 ? 1.e-6 : this->eps;
return check_result(1, Result_3_0, Result_cpu, eps);
}
public:
/**
......@@ -54,6 +74,8 @@ class SgemmFunction : public CublasFunction {
CUDA_SAFE_CALL(cudaFree(Parameter_0_0));
CUDA_SAFE_CALL(cudaFree(Parameter_1_0));
CUDA_SAFE_CALL(cudaFree(Result_3_0));
CUDA_SAFE_CALL(cudaFreeHost(Parameter_0_0_host));
CUDA_SAFE_CALL(cudaFreeHost(Parameter_1_0_host));
cuda_free(&cublas_handle);
}
};
......@@ -65,6 +87,9 @@ class CgemmFunction : public CublasFunction {
cuComplex *Parameter_0_0;
cuComplex *Parameter_1_0;
cuComplex *Result_3_0;
    cuComplex *Parameter_0_0_host;   ///< the pointer of the first input data on host
    cuComplex *Parameter_1_0_host;   ///< the pointer of the second input data on host
    std::complex<float> *Result_cpu; ///< the pointer of the reference result computed on CPU
/**
* @brief Execute the kernel/function
*/
......@@ -73,11 +98,24 @@ class CgemmFunction : public CublasFunction {
reinterpret_cast<const cuComplex *>(Parameter_0_0), reinterpret_cast<const cuComplex *>(Parameter_1_0),
reinterpret_cast<cuComplex *>(Result_3_0));
}
    /**
     * @brief Function calculation on CPU side
     *
     * Recomputes the complex GEMM on the host from the host-side input copies
     * into Result_cpu (std::complex<float>), which correctness_check() compares
     * against the GPU output Result_3_0.
     */
    virtual void matrix_calculation_on_cpu() {
        matrix_calculation_on_cpu_with_data(Parameter_0_0_host, Parameter_1_0_host, Result_3_0, &Result_cpu);
    }
/**
* @brief Prepare memory and data of the input and output for kernel running
*/
virtual void prepare_tensor() {
CublasFunction::prepare_tensor_cucomplex(&Parameter_0_0, &Parameter_1_0, &Result_3_0);
prepare_tensor_template(&Parameter_0_0, &Parameter_1_0, &Result_3_0, &Parameter_0_0_host, &Parameter_1_0_host);
}
/**
* @brief Check the correctness of function calculation result
*/
virtual int correctness_check() {
double eps = this->eps == 0.0 ? 1.e-6 : this->eps;
return check_result(1, Result_3_0, Result_cpu, eps);
}
public:
......@@ -104,6 +142,8 @@ class CgemmFunction : public CublasFunction {
CUDA_SAFE_CALL(cudaFree(Parameter_0_0));
CUDA_SAFE_CALL(cudaFree(Parameter_1_0));
CUDA_SAFE_CALL(cudaFree(Result_3_0));
CUDA_SAFE_CALL(cudaFreeHost(Parameter_0_0_host));
CUDA_SAFE_CALL(cudaFreeHost(Parameter_1_0_host));
cuda_free(&cublas_handle);
}
};
......@@ -112,9 +152,12 @@ class CgemmFunction : public CublasFunction {
* @brief Class of GemmExFunction
*/
class GemmExFunction : public CublasFunction {
float *Parameter_0_0;
float *Parameter_1_0;
float *Result_3_0;
void *Parameter_0_0;
void *Parameter_1_0;
void *Result_3_0;
void *Parameter_0_0_host;
void *Parameter_1_0_host;
void *Result_cpu;
/**
* @brief Execute the kernel/function
*/
......@@ -126,7 +169,49 @@ class GemmExFunction : public CublasFunction {
/**
* @brief Prepare memory and data of the input and output for kernel running
*/
virtual void prepare_tensor() { CublasFunction::prepare_tensor_float(&Parameter_0_0, &Parameter_1_0, &Result_3_0); }
virtual void prepare_tensor() {
if (this->datatype_.compare("half")) {
CublasFunction::prepare_tensor_template<half>(
reinterpret_cast<half **>(&Parameter_0_0), reinterpret_cast<half **>(&Parameter_1_0),
reinterpret_cast<half **>(&Result_3_0), reinterpret_cast<half **>(&Parameter_0_0_host),
reinterpret_cast<half **>(&Parameter_1_0_host));
} else if (this->datatype_.compare("float")) {
CublasFunction::prepare_tensor_template<float>(
reinterpret_cast<float **>(&Parameter_0_0), reinterpret_cast<float **>(&Parameter_1_0),
reinterpret_cast<float **>(&Result_3_0), reinterpret_cast<float **>(&Parameter_0_0_host),
reinterpret_cast<float **>(&Parameter_1_0_host));
}
}
    /**
     * @brief Function calculation on CPU side
     *
     * Recomputes the GEMM on the host into Result_cpu (always float, even on
     * the half path) for comparison by correctness_check().
     *
     * NOTE(review): std::string::compare() returns 0 on EQUALITY, so each
     * branch here fires when datatype_ does NOT match its literal — verify
     * this matches the convention used by the kernel dispatch code.
     */
    virtual void matrix_calculation_on_cpu() {
        if (this->datatype_.compare("half")) {
            matrix_calculation_on_cpu_with_data(
                reinterpret_cast<half *>(Parameter_0_0_host), reinterpret_cast<half *>(Parameter_1_0_host),
                reinterpret_cast<half *>(Result_3_0), reinterpret_cast<float **>(&Result_cpu));
        } else if (this->datatype_.compare("float")) {
            matrix_calculation_on_cpu_with_data(
                reinterpret_cast<float *>(Parameter_0_0_host), reinterpret_cast<float *>(Parameter_1_0_host),
                reinterpret_cast<float *>(Result_3_0), reinterpret_cast<float **>(&Result_cpu));
        }
    }
    /**
     * @brief Check the correctness of function calculation result
     *
     * Default tolerance is looser on the half path (1e-3) than on the float
     * path (1e-6); --eps overrides both. The CPU reference (Result_cpu) is
     * always stored as float.
     *
     * NOTE(review): compare() returns 0 on equality, so each branch fires
     * when datatype_ does NOT match its literal — confirm intended convention.
     * Also: this passes batch_count_ for the (non-strided) GemmEx — verify it
     * is 1 here, or that check_result tolerates it.
     * @return 0 when no branch matches, otherwise the check_result value
     */
    virtual int correctness_check() {
        int result = 0;
        if (this->datatype_.compare("half")) {
            double eps = this->eps == 0.0 ? 1.e-3 : this->eps;
            result = check_result(this->batch_count_, reinterpret_cast<half *>(Result_3_0),
                                  reinterpret_cast<float *>(Result_cpu), eps);
        } else if (this->datatype_.compare("float")) {
            double eps = this->eps == 0.0 ? 1.e-6 : this->eps;
            result = check_result(this->batch_count_, reinterpret_cast<float *>(Result_3_0),
                                  reinterpret_cast<float *>(Result_cpu), eps);
        }
        return result;
    }
public:
/**
......@@ -152,6 +237,8 @@ class GemmExFunction : public CublasFunction {
CUDA_SAFE_CALL(cudaFree(Parameter_0_0));
CUDA_SAFE_CALL(cudaFree(Parameter_1_0));
CUDA_SAFE_CALL(cudaFree(Result_3_0));
CUDA_SAFE_CALL(cudaFreeHost(Parameter_0_0_host));
CUDA_SAFE_CALL(cudaFreeHost(Parameter_1_0_host));
cuda_free(&cublas_handle);
}
};
......@@ -160,9 +247,12 @@ class GemmExFunction : public CublasFunction {
* @brief Class of GemmStridedBatchedExFunction
*/
class GemmStridedBatchedExFunction : public CublasFunction {
float *Parameter_0_0;
float *Parameter_1_0;
float *Result_3_0;
void *Parameter_0_0;
void *Parameter_1_0;
void *Result_3_0;
void *Parameter_0_0_host;
void *Parameter_1_0_host;
void *Result_cpu;
/**
* @brief Execute the kernel/function
*/
......@@ -175,7 +265,49 @@ class GemmStridedBatchedExFunction : public CublasFunction {
/**
* @brief Prepare memory and data of the input and output for kernel running
*/
virtual void prepare_tensor() { CublasFunction::prepare_tensor_float(&Parameter_0_0, &Parameter_1_0, &Result_3_0); }
virtual void prepare_tensor() {
if (this->datatype_.compare("half")) {
prepare_tensor_template<half>(
reinterpret_cast<half **>(&Parameter_0_0), reinterpret_cast<half **>(&Parameter_1_0),
reinterpret_cast<half **>(&Result_3_0), reinterpret_cast<half **>(&Parameter_0_0_host),
reinterpret_cast<half **>(&Parameter_1_0_host));
} else if (this->datatype_.compare("float")) {
prepare_tensor_template<float>(
reinterpret_cast<float **>(&Parameter_0_0), reinterpret_cast<float **>(&Parameter_1_0),
reinterpret_cast<float **>(&Result_3_0), reinterpret_cast<float **>(&Parameter_0_0_host),
reinterpret_cast<float **>(&Parameter_1_0_host));
}
}
/**
* @brief Function calculation on CPU side
*/
virtual void matrix_calculation_on_cpu() {
if (this->datatype_.compare("half")) {
matrix_calculation_on_cpu_with_data(
reinterpret_cast<half *>(Parameter_0_0_host), reinterpret_cast<half *>(Parameter_1_0_host),
reinterpret_cast<half *>(Result_3_0), reinterpret_cast<float **>(&Result_cpu));
} else if (this->datatype_.compare("float"), 1.0f, 1.0f) {
matrix_calculation_on_cpu_with_data(
reinterpret_cast<float *>(Parameter_0_0_host), reinterpret_cast<float *>(Parameter_1_0_host),
reinterpret_cast<float *>(Result_3_0), reinterpret_cast<float **>(&Result_cpu), 1.0f, 1.0f);
}
}
    /**
     * @brief Check the correctness of function calculation result
     *
     * Compares all batch_count_ output matrices against the CPU reference.
     * Default tolerance is 1e-3 on the half path and 1e-6 on the float path;
     * --eps overrides both.
     *
     * NOTE(review): compare() returns 0 on equality, so each branch fires
     * when datatype_ does NOT match its literal — confirm intended convention.
     * @return 0 when no branch matches, otherwise the check_result value
     */
    virtual int correctness_check() {
        int result = 0;
        if (this->datatype_.compare("half")) {
            double eps = this->eps == 0.0 ? 1.e-3 : this->eps;
            result = check_result(this->batch_count_, reinterpret_cast<half *>(Result_3_0),
                                  reinterpret_cast<float *>(Result_cpu), eps);
        } else if (this->datatype_.compare("float")) {
            double eps = this->eps == 0.0 ? 1.e-6 : this->eps;
            result = check_result(this->batch_count_, reinterpret_cast<float *>(Result_3_0),
                                  reinterpret_cast<float *>(Result_cpu), eps);
        }
        return result;
    }
public:
/**
......@@ -195,6 +327,8 @@ class GemmStridedBatchedExFunction : public CublasFunction {
CUDA_SAFE_CALL(cudaFree(Parameter_0_0));
CUDA_SAFE_CALL(cudaFree(Parameter_1_0));
CUDA_SAFE_CALL(cudaFree(Result_3_0));
CUDA_SAFE_CALL(cudaFreeHost(Parameter_0_0_host));
CUDA_SAFE_CALL(cudaFreeHost(Parameter_1_0_host));
cuda_free(&cublas_handle);
}
};
......@@ -206,6 +340,9 @@ class SgemmStridedBatchedFunction : public CublasFunction {
float *Parameter_0_0;
float *Parameter_1_0;
float *Result_3_0;
float *Parameter_0_0_host;
float *Parameter_1_0_host;
float *Result_cpu;
/**
* @brief Execute the kernel/function
*/
......@@ -218,7 +355,23 @@ class SgemmStridedBatchedFunction : public CublasFunction {
/**
* @brief Prepare memory and data of the input and output for kernel running
*/
virtual void prepare_tensor() { CublasFunction::prepare_tensor_float(&Parameter_0_0, &Parameter_1_0, &Result_3_0); }
virtual void prepare_tensor() {
prepare_tensor_template(&Parameter_0_0, &Parameter_1_0, &Result_3_0, &Parameter_0_0_host, &Parameter_1_0_host);
}
    /**
     * @brief Function calculation on CPU side
     *
     * Recomputes the strided-batched SGEMM on the host into Result_cpu; the
     * trailing 1.0f, 1.0f are presumably alpha/beta scaling factors — confirm
     * in matrix_calculation_on_cpu_with_data.
     */
    virtual void matrix_calculation_on_cpu() {
        matrix_calculation_on_cpu_with_data(Parameter_0_0_host, Parameter_1_0_host, Result_3_0, &Result_cpu, 1.0f,
                                            1.0f);
    }
/**
* @brief Check the correctness of function calculation result
*/
virtual int correctness_check() {
double eps = this->eps == 0.0 ? 1.e-6 : this->eps;
return check_result(this->batch_count_, Result_3_0, Result_cpu, eps);
}
public:
/**
......@@ -238,6 +391,8 @@ class SgemmStridedBatchedFunction : public CublasFunction {
CUDA_SAFE_CALL(cudaFree(Parameter_0_0));
CUDA_SAFE_CALL(cudaFree(Parameter_1_0));
CUDA_SAFE_CALL(cudaFree(Result_3_0));
CUDA_SAFE_CALL(cudaFreeHost(Parameter_0_0_host));
CUDA_SAFE_CALL(cudaFreeHost(Parameter_1_0_host));
cuda_free(&cublas_handle);
}
};
......@@ -249,6 +404,9 @@ class Cgemm3mStridedBatchedFunction : public CublasFunction {
cuComplex *Parameter_0_0;
cuComplex *Parameter_1_0;
cuComplex *Result_3_0;
cuComplex *Parameter_0_0_host;
cuComplex *Parameter_1_0_host;
std::complex<float> *Result_cpu;
/**
* @brief Execute the kernel/function
*/
......@@ -262,7 +420,20 @@ class Cgemm3mStridedBatchedFunction : public CublasFunction {
* @brief Prepare memory and data of the input and output for kernel running
*/
virtual void prepare_tensor() {
CublasFunction::prepare_tensor_cucomplex(&Parameter_0_0, &Parameter_1_0, &Result_3_0);
prepare_tensor_template(&Parameter_0_0, &Parameter_1_0, &Result_3_0, &Parameter_0_0_host, &Parameter_1_0_host);
}
    /**
     * @brief Function calculation on CPU side
     *
     * Recomputes the complex strided-batched GEMM on the host from the
     * host-side input copies into Result_cpu (std::complex<float>), the
     * reference used by correctness_check().
     */
    virtual void matrix_calculation_on_cpu() {
        matrix_calculation_on_cpu_with_data(Parameter_0_0_host, Parameter_1_0_host, Result_3_0, &Result_cpu);
    }
/**
* @brief Check the correctness of function calculation result
*/
virtual int correctness_check() {
double eps = this->eps == 0.0 ? 1.e-6 : this->eps;
return check_result(this->batch_count_, Result_3_0, Result_cpu, eps);
}
public:
......@@ -283,6 +454,8 @@ class Cgemm3mStridedBatchedFunction : public CublasFunction {
CUDA_SAFE_CALL(cudaFree(Parameter_0_0));
CUDA_SAFE_CALL(cudaFree(Parameter_1_0));
CUDA_SAFE_CALL(cudaFree(Result_3_0));
CUDA_SAFE_CALL(cudaFreeHost(Parameter_0_0_host));
CUDA_SAFE_CALL(cudaFreeHost(Parameter_1_0_host));
cuda_free(&cublas_handle);
}
};
......@@ -8,6 +8,7 @@
#pragma once
#include <cstring>
#include <fstream>
#include <iostream>
#include <limits>
......@@ -52,6 +53,18 @@ class Options {
return 0;
}
/**
* @brief Get the double type value of cmd line argument
* @param option the cmd line argument
* @return double the double type value of cmd line argument 'option'
*/
double get_cmd_line_argument_double(const std::string &option) {
if (char *value = get_cmd_option(option)) {
return std::atof(value);
}
return 0.0;
}
/**
* @brief Get the string type value of cmd line argument
* @param option the cmd line argument
......@@ -64,12 +77,27 @@ class Options {
return "";
}
/**
* @brief Get the bool type value of cmd line argument
* @param option the cmd line argument
* @return std::string the int type value of cmd line argument 'option'
*/
bool get_cmd_line_argument_bool(const std::string &option) {
char **itr = std::find(begin, end, option);
if (itr != end) {
return true;
}
return false;
}
public:
    int num_test;
    int warm_up;
    int num_in_step;
    int random_seed;
    std::string para_info_json;
    bool correctness_check; ///< true when the --correctness flag is present on the command line
    double eps;             ///< tolerance from --eps; 0.0 (meaning "use per-function default") when absent
/**
* @brief Construct a options object according to cmd or set a default value used to test
......@@ -90,6 +118,8 @@ class Options {
para_info_json = get_cmd_line_argument_string("--config_json");
para_info_json = para_info_json == "" ? R"({"name":"cublasCgemm","m":512,"n":512,"k":32,"transa":1,"transb":0})"
: para_info_json;
correctness_check = get_cmd_line_argument_bool("--correctness");
eps = get_cmd_line_argument_double("--eps");
}
};
......@@ -197,6 +227,8 @@ void run_benchmark(Options &options) {
function.set_warm_up(options.warm_up);
function.set_num_in_step(options.num_in_step);
function.set_random_seed(options.random_seed);
function.set_correctness(options.correctness_check);
function.set_eps(options.eps);
CublasFunction *p_function = get_cublas_function_pointer(function);
p_function->benchmark();
delete p_function;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.