"vscode:/vscode.git/clone" did not exist on "fcbf7bfce6b6297e2c12c80caf7880014fbf3fba"
Unverified Commit 87f6b371 authored by Yuting Jiang's avatar Yuting Jiang Committed by GitHub
Browse files

Benchmarks: Add benchmark - add source code of cublas function micro benchmark (#77)



* Superbenchmark: Add benchmarks - add cublas function micro benchmark

* format

* add python benchmark for cublas functions, example and test file

* delete python-related code and rename some files

* revise cmd_helper and move json package to cmake

* resolve conflict

* revise error handling to try-catch and update some code style

* revise cmd_helper.h, cublas_helper.h, cublas_helper.cpp

* revise structure of the cublas function

* add some comments and move cuda_init and cuda_free

* add comments for class member

* add random seed, revise input from file to json string, simplify cmake

* delete json file in source code of cublas

* update according to comments

* limit batchcount=1 in initialization of cublas function which do not use batch count

* revise and fix some errors of annotations

* update according comments and revise construction of CublasFunction
Co-authored-by: default avatarroot <root@sb-validation-000001.51z1chmys5fuzfqyo4niepozre.bx.internal.cloudapp.net>
parent 4489388c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

cmake_minimum_required(VERSION 3.18)
project(CublasBenchmark LANGUAGES CUDA CXX)

include(../cuda_common.cmake)

set(SRC "cublas_helper.cpp" CACHE STRING "source file")
set(TARGET_NAME "cublas_function" CACHE STRING "target name")

find_package(CUDAToolkit REQUIRED)
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} ${NVCC_ARCHS_SUPPORTED}")

add_library(${TARGET_NAME} SHARED ${SRC})
# Use target-scoped include/link commands instead of directory-scoped
# link_directories()/include_directories(): the originals were placed AFTER
# add_library(), so they never affected the already-created target. The
# CUDAToolkit imported targets carry their own include paths and link dirs.
target_include_directories(${TARGET_NAME} PUBLIC ${CUDAToolkit_INCLUDE_DIRS})
target_link_libraries(${TARGET_NAME} PUBLIC CUDA::cudart CUDA::cublas)

# Fetch nlohmann/json (header-only) at configure time, pinned to v3.7.3.
include(FetchContent)
FetchContent_Declare(json
    GIT_REPOSITORY https://github.com/ArthurSonzogni/nlohmann_json_cmake_fetchcontent
    GIT_TAG v3.7.3)
FetchContent_GetProperties(json)
if(NOT json_POPULATED)
    FetchContent_Populate(json)
    add_subdirectory(${json_SOURCE_DIR} ${json_BINARY_DIR} EXCLUDE_FROM_ALL)
endif()

add_executable(CublasBenchmark cublas_test.cpp)
target_link_libraries(CublasBenchmark ${TARGET_NAME} nlohmann_json::nlohmann_json CUDA::cudart CUDA::cublas)

install(TARGETS CublasBenchmark ${TARGET_NAME} RUNTIME DESTINATION bin LIBRARY DESTINATION lib)
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/**
* @file cublas_benchmark.h
* @brief Unify a base class for cublas function benchmark
*/
#pragma once
#include <chrono>
#include <iostream>
#include <stdexcept>
#include <stdlib.h>
#include <time.h>
#include <unordered_map>
#include <vector>
#include "cublas_helper.h"
/**
 * @brief Enum of the cublas functions supported by this benchmark.
 *
 * One enumerator per wrapped cuBLAS entry point; used to dispatch to the
 * matching CublasFunction subclass after parsing the "name" field of the
 * JSON config (see cublas_function_name_string below).
 */
enum cublas_function_name_enum {
    e_cublasSgemm = 0,
    e_cublasCgemm,
    e_cublasGemmEx,
    e_cublasGemmStridedBatchedEx,
    e_cublasSgemmStridedBatched,
    e_cublasCgemm3mStridedBatched
};
/**
 * @brief Map from cublas function name (as written in the JSON config's
 * "name" field) to the corresponding cublas_function_name_enum value.
 *
 * Lookup is performed by CublasFunction::name2enum(); an unknown name is
 * rejected there.
 */
static std::unordered_map<std::string, cublas_function_name_enum> const cublas_function_name_string = {
    {"cublasSgemm", cublas_function_name_enum::e_cublasSgemm},
    {"cublasCgemm", cublas_function_name_enum::e_cublasCgemm},
    {"cublasGemmEx", cublas_function_name_enum::e_cublasGemmEx},
    {"cublasGemmStridedBatchedEx", cublas_function_name_enum::e_cublasGemmStridedBatchedEx},
    {"cublasSgemmStridedBatched", cublas_function_name_enum::e_cublasSgemmStridedBatched},
    {"cublasCgemm3mStridedBatched", cublas_function_name_enum::e_cublasCgemm3mStridedBatched},
};
/**
 * @brief Class to store params of a cublas function and run the benchmark of this function.
 *
 * Holds the benchmark configuration (steps, warmup, random seed) and the GEMM
 * shape/type parameters parsed from JSON. Subclasses implement prepare_tensor()
 * and kernel_entry() for each specific cuBLAS call.
 */
class CublasFunction {
  protected:
    int num_test = 1;               ///< the number of steps used to test and measure
    int warm_up = 1;                ///< the number of steps used to warm up
    int num_in_step = 100;          ///< the number of function invocations in a step
    int random_seed = 0;            ///< the random seed used to generate random data
    std::string name_;              ///< the name of the cublas function
    int m_ = 0;                     ///< the m dim of matrix
    int k_ = 0;                     ///< the k dim of matrix
    int n_ = 0;                     ///< the n dim of matrix
    int transa_ = 0;                ///< whether the first matrix is transposed
    int transb_ = 0;                ///< whether the second matrix is transposed
    std::string datatype_;          ///< data type used in cublasGemmEx and cublasGemmStridedBatchedEx
    bool use_tensor_core_ = false;  ///< choose the algo used in cublasGemmEx and cublasGemmStridedBatchedEx
    int batch_count_ = 1;           ///< the number of the batch used in some cublas functions
    cublas_function_name_enum e_name_ = e_cublasSgemm; ///< enum cublas function name
    std::string function_str_;      ///< the str representing the cublas function with params
    cublasHandle_t cublas_handle = nullptr; ///< the handle of cublas function
    /**
     * @brief Fill random data into the two input buffers in float type
     */
    void fill_data_float(float *Parameter_0_0_host, float *Parameter_1_0_host);
    /**
     * @brief Fill random data into the two input buffers in cuComplex type
     */
    void fill_data_cucomplex(cuComplex *Parameter_0_0_host, cuComplex *Parameter_1_0_host);
    /**
     * @brief Prepare memory and data of the input and output in float type
     */
    void prepare_tensor_float(float **Parameter_0_0, float **Parameter_1_0, float **Result_3_0);
    /**
     * @brief Prepare memory and data of the input and output in cuComplex type
     */
    void prepare_tensor_cucomplex(cuComplex **Parameter_0_0, cuComplex **Parameter_1_0, cuComplex **Result_3_0);
    /**
     * @brief Prepare memory and data of the input and output for kernel running
     */
    virtual void prepare_tensor() {}
    /**
     * @brief Execute the kernel/function
     */
    virtual void kernel_entry() {}

  public:
    /**
     * @brief Set the num test member
     * @param num_test the number of steps used to test and measure
     */
    void set_num_test(int num_test) { this->num_test = num_test; }
    /**
     * @brief Set the warm up member
     * @param warm_up the number of steps used to warm up
     */
    void set_warm_up(int warm_up) { this->warm_up = warm_up; }
    /**
     * @brief Set the num in step member
     * @param num_in_step the number of function invocations in a step
     */
    void set_num_in_step(int num_in_step) { this->num_in_step = num_in_step; }
    /**
     * @brief Set the random seed
     * @param random_seed random seed
     */
    void set_random_seed(int random_seed) { this->random_seed = random_seed; }
    /**
     * @brief Set the params string
     * @param str the str representing the params of the function
     */
    void set_function(std::string &str) { this->function_str_ = str; }
    /**
     * @brief Set the name member
     * @param name the name of the cublas function
     */
    void set_name(std::string &name) { this->name_ = name; }
    /**
     * @brief Set the m
     * @param m the m dim of matrix
     */
    void set_m(int m) { this->m_ = m; }
    /**
     * @brief Set the n
     * @param n the n dim of matrix
     */
    void set_n(int n) { this->n_ = n; }
    /**
     * @brief Set the k
     * @param k the k dim of matrix
     */
    void set_k(int k) { this->k_ = k; }
    /**
     * @brief Set the transa
     * @param transa whether the first matrix is transposed
     */
    void set_transa(int transa) { this->transa_ = transa; }
    /**
     * @brief Set the transb
     * @param transb whether the second matrix is transposed
     */
    void set_transb(int transb) { this->transb_ = transb; }
    /**
     * @brief Set the datatype
     * @param datatype data type used in cublasGemmEx and cublasGemmStridedBatchedEx
     */
    void set_datatype(std::string datatype) { this->datatype_ = datatype; }
    /**
     * @brief Set the use_tensor_core
     * @param use_tensor_core choose the algo used in cublasGemmEx and cublasGemmStridedBatchedEx
     */
    void set_use_tensor_core(bool use_tensor_core) { this->use_tensor_core_ = use_tensor_core; }
    /**
     * @brief Set the batch count
     * @param batch_count the num of the batch
     */
    void set_batch_count(int batch_count) { this->batch_count_ = batch_count; }
    /**
     * @brief Get the e name
     * @return cublas_function_name_enum
     */
    cublas_function_name_enum get_e_name() { return e_name_; }
    /**
     * @brief Get the name object
     * @return std::string name of the function
     */
    std::string get_name() { return this->name_; }
    /**
     * @brief Convert function name to enum type
     * @return cublas_function_name_enum
     * @throws std::runtime_error if the function name is not in
     *         cublas_function_name_string. The original code threw a raw
     *         `const char*`, which is NOT caught by the project's
     *         `catch (std::exception&)` handler and would terminate the process.
     */
    cublas_function_name_enum name2enum() {
        auto it = cublas_function_name_string.find(this->name_);
        if (it != cublas_function_name_string.end()) {
            this->e_name_ = it->second;
            return e_name_;
        } else {
            throw std::runtime_error("invalid input function name");
        }
    }
    /**
     * @brief The main procedure for cublas function test, including warmup, function test,
     * time measurement and output of raw data results
     */
    void benchmark();
    /**
     * @brief Destroy the Cublas Function object
     */
    virtual ~CublasFunction() {}
};
/**
 * @brief Fill random data into the two input buffers in float type.
 *
 * Seeds rand() with the configured random_seed so runs are reproducible,
 * then fills an m_*k_ buffer and a k_*n_ buffer with values in [0, 1].
 * (Original comment incorrectly said "cuComplex".)
 */
void CublasFunction::fill_data_float(float *Parameter_0_0_host, float *Parameter_1_0_host) {
    srand(random_seed);
    for (int i = 0; i < m_ * k_; i++) {
        Parameter_0_0_host[i] = (float)rand() / (float)(RAND_MAX);
    }
    for (int i = 0; i < k_ * n_; ++i) {
        Parameter_1_0_host[i] = (float)rand() / (float)(RAND_MAX);
    }
}
/**
 * @brief Fill random data into the two input buffers in cuComplex type.
 *
 * Seeds rand() with the configured random_seed for reproducibility, then
 * fills an m_*k_ buffer and a k_*n_ buffer with complex values whose real
 * and imaginary parts lie in [0, 1].
 */
void CublasFunction::fill_data_cucomplex(cuComplex *Parameter_0_0_host, cuComplex *Parameter_1_0_host) {
    srand(random_seed);
    // Helper producing one uniform sample in [0, 1].
    auto random_unit = []() { return (float)rand() / (float)(RAND_MAX); };
    const int size_a = m_ * k_;
    const int size_b = k_ * n_;
    for (int idx = 0; idx < size_a; ++idx) {
        Parameter_0_0_host[idx] = make_cuComplex(random_unit(), random_unit());
    }
    for (int idx = 0; idx < size_b; ++idx) {
        Parameter_1_0_host[idx] = make_cuComplex(random_unit(), random_unit());
    }
}
/**
 * @brief Prepare memory and data of the input and output in float type.
 *
 * Allocates pinned host staging buffers and device buffers for the two
 * inputs (m*k and n*k elements, each times batch_count_), fills the host
 * buffers with random data, copies them to the device, allocates and
 * zeroes the m*n output, then releases the host staging buffers.
 * Output pointers receive owning device allocations; the caller (subclass
 * destructor) is responsible for cudaFree-ing them.
 */
void CublasFunction::prepare_tensor_float(float **Parameter_0_0, float **Parameter_1_0, float **Result_3_0) {
    int m = this->m_;
    int n = this->n_;
    int k = this->k_;
    float *Parameter_0_0_host, *Parameter_1_0_host;
    // input argument: host staging (pinned) + device buffer, m*k per batch
    CUDA_SAFE_CALL(cudaMallocHost((void **)&Parameter_0_0_host, sizeof(float) * m * k * this->batch_count_));
    CUDA_SAFE_CALL(cudaMalloc((void **)Parameter_0_0, sizeof(float) * m * k * this->batch_count_));
    // input argument: host staging (pinned) + device buffer, n*k per batch
    CUDA_SAFE_CALL(cudaMallocHost((void **)&Parameter_1_0_host, sizeof(float) * n * k * this->batch_count_));
    CUDA_SAFE_CALL(cudaMalloc((void **)Parameter_1_0, sizeof(float) * n * k * this->batch_count_));
    // fill input values
    fill_data_float(Parameter_0_0_host, Parameter_1_0_host);
    // copy input data from host to device
    CUDA_SAFE_CALL(cudaMemcpy(*Parameter_0_0, Parameter_0_0_host, sizeof(float) * m * k * this->batch_count_,
                              cudaMemcpyHostToDevice));
    CUDA_SAFE_CALL(cudaMemcpy(*Parameter_1_0, Parameter_1_0_host, sizeof(float) * k * n * this->batch_count_,
                              cudaMemcpyHostToDevice));
    // output arguments: device-only, zero-initialized
    CUDA_SAFE_CALL(cudaMalloc((void **)Result_3_0, sizeof(float) * m * n * batch_count_));
    CUDA_SAFE_CALL(cudaMemset((void *)*Result_3_0, 0, sizeof(float) * m * n * batch_count_));
    CUDA_SAFE_CALL(cudaFreeHost(Parameter_0_0_host));
    CUDA_SAFE_CALL(cudaFreeHost(Parameter_1_0_host));
}
/**
 * @brief Prepare memory and data of the input and output in cuComplex type.
 *
 * Mirrors prepare_tensor_float for complex data: pinned host staging +
 * device buffers for the two inputs, random fill, host-to-device copy,
 * zeroed m*n device output, then host staging release. The device
 * allocations are owned by the caller (subclass destructor frees them).
 */
void CublasFunction::prepare_tensor_cucomplex(cuComplex **Parameter_0_0, cuComplex **Parameter_1_0,
                                              cuComplex **Result_3_0) {
    int m = this->m_;
    int n = this->n_;
    int k = this->k_;
    cuComplex *Parameter_0_0_host, *Parameter_1_0_host;
    // input argument: host staging (pinned) + device buffer, m*k per batch
    CUDA_SAFE_CALL(cudaMallocHost((void **)&Parameter_0_0_host, sizeof(cuComplex) * m * k * this->batch_count_));
    CUDA_SAFE_CALL(cudaMalloc((void **)Parameter_0_0, sizeof(cuComplex) * m * k * this->batch_count_));
    // input argument: host staging (pinned) + device buffer, n*k per batch
    CUDA_SAFE_CALL(cudaMallocHost((void **)&Parameter_1_0_host, sizeof(cuComplex) * n * k * this->batch_count_));
    CUDA_SAFE_CALL(cudaMalloc((void **)Parameter_1_0, sizeof(cuComplex) * n * k * this->batch_count_));
    // fill input values
    fill_data_cucomplex(Parameter_0_0_host, Parameter_1_0_host);
    // copy input data from host to device
    CUDA_SAFE_CALL(cudaMemcpy(*Parameter_0_0, Parameter_0_0_host, sizeof(cuComplex) * m * k * this->batch_count_,
                              cudaMemcpyHostToDevice));
    CUDA_SAFE_CALL(cudaMemcpy(*Parameter_1_0, Parameter_1_0_host, sizeof(cuComplex) * k * n * this->batch_count_,
                              cudaMemcpyHostToDevice));
    // output arguments: device-only, zero-initialized
    CUDA_SAFE_CALL(cudaMalloc((void **)Result_3_0, sizeof(cuComplex) * m * n * batch_count_));
    CUDA_SAFE_CALL(cudaMemset((void *)*Result_3_0, 0, sizeof(cuComplex) * m * n * batch_count_));
    CUDA_SAFE_CALL(cudaFreeHost(Parameter_0_0_host));
    CUDA_SAFE_CALL(cudaFreeHost(Parameter_1_0_host));
}
/**
 * @brief The main procedure for cublas function test, including warmup, function test,
 * time measurement and output of raw data results.
 *
 * Runs warm_up steps (untimed), then num_test timed steps of num_in_step
 * invocations each. Each step is bracketed by cudaDeviceSynchronize so the
 * wall-clock time covers the GPU work, and is divided by num_in_step to get
 * the average per-invocation latency in microseconds.
 */
void CublasFunction::benchmark() {
    // Malloc memory for input and output data
    this->prepare_tensor();
    // Warm up (untimed)
    for (int i = 0; i < warm_up; i++) {
        for (int j = 0; j < num_in_step; j++) {
            this->kernel_entry();
        }
    }
    CUDA_SAFE_CALL(cudaDeviceSynchronize());
    // Collected per-step average latencies, in microseconds
    std::vector<float> iteration_time;
    iteration_time.reserve(num_test);
    // Benchmark in range of steps
    for (int i = 0; i < num_test; i++) {
        // Collect time within each step, including num_in_step function invocations
        auto start = std::chrono::high_resolution_clock::now();
        for (int j = 0; j < num_in_step; j++) {
            this->kernel_entry();
        }
        CUDA_SAFE_CALL(cudaDeviceSynchronize());
        auto end = std::chrono::high_resolution_clock::now();
        // Convert step time to the average duration of a single invocation (us)
        float step_us =
            static_cast<float>(std::chrono::duration<double, std::micro>(end - start).count() / num_in_step);
        iteration_time.emplace_back(step_us);
    }
    // Output results
    std::cout << "[function config]: " << this->function_str_ << std::endl;
    std::cout << "[raw_data]: ";
    // Range-for avoids the original's signed/unsigned loop-index comparison.
    for (const float t : iteration_time) {
        std::cout << t << ",";
    }
    std::cout << std::endl;
}
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/**
* @file cublas_function.h
* @brief Implementation of specific cublas function
*/
#pragma once
#include "cublas_benchmark.h"
/**
* @brief Class of SgemmFunction
*/
class SgemmFunction : public CublasFunction {
float *Parameter_0_0; ///< the pointer of the first input data
float *Parameter_1_0; ///< the pointer of the second input data
float *Result_3_0; ///< the pointer of output data
/**
* @brief Execute the kernel/function
*/
virtual void kernel_entry() {
sgemm(cublas_handle, this->transa_, this->transb_, this->m_, this->n_, this->k_,
reinterpret_cast<const float *>(Parameter_0_0), reinterpret_cast<const float *>(Parameter_1_0),
reinterpret_cast<float *>(Result_3_0));
}
/**
* @brief Prepare memory and data of the input and output for kernel running
*/
virtual void prepare_tensor() { CublasFunction::prepare_tensor_float(&Parameter_0_0, &Parameter_1_0, &Result_3_0); }
public:
/**
* @brief Construct a new Sgemm Function object
*/
SgemmFunction() {
this->batch_count_ = 1;
cuda_init(&cublas_handle);
}
/**
* @brief Construct a new Sgemm Function object
* @param function base class CublasFunction object
*/
SgemmFunction(CublasFunction &function) : CublasFunction(function) {
this->batch_count_ = 1;
cuda_init(&cublas_handle);
}
/**
* @brief Destroy the Sgemm Function object
*/
~SgemmFunction() {
// Free contexts
CUDA_SAFE_CALL(cudaFree(Parameter_0_0));
CUDA_SAFE_CALL(cudaFree(Parameter_1_0));
CUDA_SAFE_CALL(cudaFree(Result_3_0));
cuda_free(&cublas_handle);
}
};
/**
* @brief Class of CgemmFunction
*/
class CgemmFunction : public CublasFunction {
cuComplex *Parameter_0_0;
cuComplex *Parameter_1_0;
cuComplex *Result_3_0;
/**
* @brief Execute the kernel/function
*/
virtual void kernel_entry() {
cgemm(cublas_handle, this->transa_, this->transb_, this->m_, this->n_, this->k_,
reinterpret_cast<const cuComplex *>(Parameter_0_0), reinterpret_cast<const cuComplex *>(Parameter_1_0),
reinterpret_cast<cuComplex *>(Result_3_0));
}
/**
* @brief Prepare memory and data of the input and output for kernel running
*/
virtual void prepare_tensor() {
CublasFunction::prepare_tensor_cucomplex(&Parameter_0_0, &Parameter_1_0, &Result_3_0);
}
public:
/**
* @brief Construct a new Cgemm Function object
*/
CgemmFunction() {
this->batch_count_ = 1;
cuda_init(&cublas_handle);
}
/**
* @brief Construct a new Cgemm Function object
* @param function base class CublasFunction object
*/
CgemmFunction(CublasFunction &function) : CublasFunction(function) {
this->batch_count_ = 1;
cuda_init(&cublas_handle);
}
/**
* @brief Destroy the Cgemm Function object
*/
~CgemmFunction() {
// Free contexts
CUDA_SAFE_CALL(cudaFree(Parameter_0_0));
CUDA_SAFE_CALL(cudaFree(Parameter_1_0));
CUDA_SAFE_CALL(cudaFree(Result_3_0));
cuda_free(&cublas_handle);
}
};
/**
* @brief Class of GemmExFunction
*/
class GemmExFunction : public CublasFunction {
float *Parameter_0_0;
float *Parameter_1_0;
float *Result_3_0;
/**
* @brief Execute the kernel/function
*/
virtual void kernel_entry() {
gemmEx(cublas_handle, this->transa_, this->transb_, this->m_, this->n_, this->k_,
reinterpret_cast<void *>(Parameter_0_0), reinterpret_cast<void *>(Parameter_1_0),
reinterpret_cast<void *>(Result_3_0), this->datatype_, this->use_tensor_core_);
}
/**
* @brief Prepare memory and data of the input and output for kernel running
*/
virtual void prepare_tensor() { CublasFunction::prepare_tensor_float(&Parameter_0_0, &Parameter_1_0, &Result_3_0); }
public:
/**
* @brief Construct a new Gemm Ex Function object
*/
GemmExFunction() {
this->batch_count_ = 1;
cuda_init(&cublas_handle);
}
/**
* @brief Construct a new Gemm Ex Function object
* @param function base class CublasFunction object
*/
GemmExFunction(CublasFunction &function) : CublasFunction(function) {
this->batch_count_ = 1;
cuda_init(&cublas_handle);
}
/**
* @brief Destroy the Gemm Ex Function object
*/
~GemmExFunction() {
// Free contexts
CUDA_SAFE_CALL(cudaFree(Parameter_0_0));
CUDA_SAFE_CALL(cudaFree(Parameter_1_0));
CUDA_SAFE_CALL(cudaFree(Result_3_0));
cuda_free(&cublas_handle);
}
};
/**
* @brief Class of GemmStridedBatchedExFunction
*/
class GemmStridedBatchedExFunction : public CublasFunction {
float *Parameter_0_0;
float *Parameter_1_0;
float *Result_3_0;
/**
* @brief Execute the kernel/function
*/
virtual void kernel_entry() {
gemmStridedBatchedEx(cublas_handle, this->transa_, this->transb_, this->m_, this->n_, this->k_,
reinterpret_cast<void *>(Parameter_0_0), reinterpret_cast<void *>(Parameter_1_0),
reinterpret_cast<void *>(Result_3_0), this->datatype_, this->use_tensor_core_,
this->batch_count_);
}
/**
* @brief Prepare memory and data of the input and output for kernel running
*/
virtual void prepare_tensor() { CublasFunction::prepare_tensor_float(&Parameter_0_0, &Parameter_1_0, &Result_3_0); }
public:
/**
* @brief Construct a new Gemm Strided Batched Ex Function object
*/
GemmStridedBatchedExFunction() { cuda_init(&cublas_handle); }
/**
* @brief Construct a new Gemm Strided Batched Ex Function object
* @param function base class CublasFunction object
*/
GemmStridedBatchedExFunction(CublasFunction &function) : CublasFunction(function) { cuda_init(&cublas_handle); }
/**
* @brief Destroy the Gemm Strided Batched Ex Function object
*/
~GemmStridedBatchedExFunction() {
// Free contexts
CUDA_SAFE_CALL(cudaFree(Parameter_0_0));
CUDA_SAFE_CALL(cudaFree(Parameter_1_0));
CUDA_SAFE_CALL(cudaFree(Result_3_0));
cuda_free(&cublas_handle);
}
};
/**
* @brief Class of SgemmStridedBatchedFunction
*/
class SgemmStridedBatchedFunction : public CublasFunction {
float *Parameter_0_0;
float *Parameter_1_0;
float *Result_3_0;
/**
* @brief Execute the kernel/function
*/
virtual void kernel_entry() {
sgemmStridedBatched(cublas_handle, this->transa_, this->transb_, this->m_, this->n_, this->k_,
reinterpret_cast<const float *>(Parameter_0_0),
reinterpret_cast<const float *>(Parameter_1_0), reinterpret_cast<float *>(Result_3_0),
this->batch_count_);
}
/**
* @brief Prepare memory and data of the input and output for kernel running
*/
virtual void prepare_tensor() { CublasFunction::prepare_tensor_float(&Parameter_0_0, &Parameter_1_0, &Result_3_0); }
public:
/**
* @brief Construct a new Sgemm Strided Batched Function object
*/
SgemmStridedBatchedFunction() { cuda_init(&cublas_handle); }
/**
* @brief Construct a new Sgemm Strided Batched Function object
* @param function base class CublasFunction object
*/
SgemmStridedBatchedFunction(CublasFunction &function) : CublasFunction(function) { cuda_init(&cublas_handle); }
/**
* @brief Destroy the Sgemm Strided Batched Function object
*/
~SgemmStridedBatchedFunction() {
// Free contexts
CUDA_SAFE_CALL(cudaFree(Parameter_0_0));
CUDA_SAFE_CALL(cudaFree(Parameter_1_0));
CUDA_SAFE_CALL(cudaFree(Result_3_0));
cuda_free(&cublas_handle);
}
};
/**
* @brief Class of Cgemm3mStridedBatchedFunction
*/
class Cgemm3mStridedBatchedFunction : public CublasFunction {
cuComplex *Parameter_0_0;
cuComplex *Parameter_1_0;
cuComplex *Result_3_0;
/**
* @brief Execute the kernel/function
*/
virtual void kernel_entry() {
cgemm3mStridedBatched(cublas_handle, this->transa_, this->transb_, this->m_, this->n_, this->k_,
reinterpret_cast<const cuComplex *>(Parameter_0_0),
reinterpret_cast<const cuComplex *>(Parameter_1_0),
reinterpret_cast<cuComplex *>(Result_3_0), this->batch_count_);
}
/**
* @brief Prepare memory and data of the input and output for kernel running
*/
virtual void prepare_tensor() {
CublasFunction::prepare_tensor_cucomplex(&Parameter_0_0, &Parameter_1_0, &Result_3_0);
}
public:
/**
* @brief Construct a new Cgemm 3m Strided Batched Function object
*/
Cgemm3mStridedBatchedFunction() { cuda_init(&cublas_handle); }
/**
* @brief Construct a new Cgemm 3m Strided Batched Function object according to base class object
* @param function base class CublasFunction object
*/
Cgemm3mStridedBatchedFunction(CublasFunction &function) : CublasFunction(function) { cuda_init(&cublas_handle); }
/**
* @brief Destroy the Cgemm 3m Strided Batched Function object
*/
~Cgemm3mStridedBatchedFunction() {
// Free contexts
CUDA_SAFE_CALL(cudaFree(Parameter_0_0));
CUDA_SAFE_CALL(cudaFree(Parameter_1_0));
CUDA_SAFE_CALL(cudaFree(Result_3_0));
cuda_free(&cublas_handle);
}
};
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/**
* @file cublas_function_helper.h
* @brief Helper for parsing command line arguments and pass params to cublas function
*/
#pragma once
#include <algorithm>
#include <ctime>
#include <fstream>
#include <iostream>
#include <limits>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include <nlohmann/json.hpp>
#include "cublas_function.h"
using json = nlohmann::json;
/**
 * @brief Utility for storing and parsing command line arguments.
 *
 * Wraps the raw argv range and exposes the benchmark parameters with
 * defaults applied when an option is absent (or parses to 0).
 */
class Options {
    char **begin; ///< pointer to the first argv entry
    char **end;   ///< pointer one past the last argv entry
    /**
     * @brief Get the char* value of the cmd line argument
     * @param option the argument name in cmd
     * @return char* the raw value following the option, or nullptr if absent
     */
    char *get_cmd_option(const std::string &option) {
        char **itr = std::find(begin, end, option);
        if (itr != end && ++itr != end) {
            return *itr;
        }
        return nullptr;
    }
    /**
     * @brief Get the int type value of cmd line argument
     * @param option the cmd line argument
     * @return int the parsed value, or 0 when the option is absent
     *
     * NOTE(review): std::stoi throws on non-numeric input; callers rely on
     * the surrounding try/catch in run_benchmark for that case.
     */
    int get_cmd_line_argument_int(const std::string &option) {
        if (char *value = get_cmd_option(option)) {
            return std::stoi(value);
        }
        return 0;
    }
    /**
     * @brief Get the string type value of cmd line argument
     * @param option the cmd line argument
     * @return std::string the value, or "" when the option is absent
     */
    std::string get_cmd_line_argument_string(const std::string &option) {
        if (char *value = get_cmd_option(option)) {
            return std::string(value);
        }
        return "";
    }

  public:
    int num_test;               ///< number of measured steps (default 1)
    int warm_up;                ///< number of warmup steps (default 1)
    int num_in_step;            ///< function invocations per step (default 100)
    int random_seed;            ///< seed for input data (default: current time)
    std::string para_info_json; ///< JSON string describing the cublas function
    /**
     * @brief Construct an Options object from cmd line, falling back to
     * default values suitable for a quick test when options are missing.
     * @param argc argument count
     * @param argv argument vector
     */
    Options(int argc, char *argv[]) {
        begin = argv;
        end = argv + argc;
        num_test = get_cmd_line_argument_int("--num_test");
        num_test = (num_test == 0 ? 1 : num_test);
        warm_up = get_cmd_line_argument_int("--warm_up");
        warm_up = (warm_up == 0 ? 1 : warm_up);
        num_in_step = get_cmd_line_argument_int("--num_in_step");
        num_in_step = (num_in_step == 0 ? 100 : num_in_step);
        random_seed = get_cmd_line_argument_int("--random_seed");
        random_seed = (random_seed == 0 ? time(NULL) : random_seed);
        para_info_json = get_cmd_line_argument_string("--config_json");
        para_info_json = para_info_json == "" ? R"({"name":"cublasCgemm","m":512,"n":512,"k":32,"transa":1,"transb":0})"
                                              : para_info_json;
    }
};
/**
* @brief Helper function to convert from json to cublasfunction
*
* The params required for each type of cublas funcion is as below:
*+-----------------------------+----------+----------+----------+----------+----------+------------+----------+-----------------+
*| name | m | n | k | transa | transb | batchCount | datatype |
*use_tensor_core |
*+-----------------------------+----------+----------+----------+----------+----------+------------+----------+-----------------+
*| cublasSgemm | required | required | required | required | required | no | no | no |
*+-----------------------------+----------+----------+----------+----------+----------+------------+----------+-----------------+
*| cublasGemmEx | required | required | required | required | required | no | required |
*required
*|
*+-----------------------------+----------+----------+----------+----------+----------+------------+----------+-----------------+
*| cublasSgemmStridedBatched | required | required | required | required | required | required | no | no |
*+-----------------------------+----------+----------+----------+----------+----------+------------+----------+-----------------+
*| cublasGemmStridedBatchedEx | required | required | required | required | required | required | required |
*required
*|
*+-----------------------------+----------+----------+----------+----------+----------+------------+----------+-----------------+
*| cublasCgemm | required | required | required | required | required | no | no | no |
*+-----------------------------+----------+----------+----------+----------+----------+------------+----------+-----------------+
*| cublasCgemm3mStridedBatched | required | required | required | required | required | required | no | no |
*+-----------------------------+----------+----------+----------+----------+----------+------------+----------+-----------------+
*
* @param j json including the params of a cublas function read from 'config_json'
* @param fn a CublasFunction object
*/
void from_json(const json &j, CublasFunction &fn) {
auto str = j.dump();
std::replace(str.begin(), str.end(), '\"', ' ');
fn.set_function(str);
auto name = j.at("name").get<std::string>();
fn.set_name(name);
auto m = j.at("m").get<int>();
fn.set_m(m);
auto n = j.at("n").get<int>();
fn.set_n(n);
auto k = j.at("k").get<int>();
fn.set_k(k);
auto transa = j.at("transa").get<int>();
fn.set_transa(transa);
auto transb = j.at("transb").get<int>();
fn.set_transb(transb);
fn.name2enum();
try {
auto batch_count = j.at("batchCount").get<int>();
fn.set_batch_count(batch_count);
} catch (std::exception &e) {
fn.set_batch_count(1);
}
try {
auto datatype = j.at("datatype").get<std::string>();
fn.set_datatype(datatype);
auto use_tensor_core = j.at("use_tensor_core").get<bool>();
fn.set_use_tensor_core(use_tensor_core);
} catch (std::exception &e) {
fn.set_datatype("float");
fn.set_use_tensor_core(false);
}
}
/**
 * @brief Get the cublas function pointer of a specific child class.
 * @param function base class object of a CublasFunction, used to initialize the base part
 *        of the child class object
 * @return CublasFunction* a heap-allocated child object; the caller owns it and
 *         must delete it (run_benchmark does so)
 * @throws std::runtime_error for an unrecognized function enum. The original
 *         threw a raw `const char*`, which `catch (std::exception&)` in
 *         run_benchmark could not catch, terminating the process instead of
 *         printing an error.
 */
CublasFunction *get_cublas_function_pointer(CublasFunction &function) {
    switch (function.get_e_name()) {
    case e_cublasSgemm:
        return new SgemmFunction(function);
    case e_cublasGemmEx:
        return new GemmExFunction(function);
    case e_cublasSgemmStridedBatched:
        return new SgemmStridedBatchedFunction(function);
    case e_cublasGemmStridedBatchedEx:
        return new GemmStridedBatchedExFunction(function);
    case e_cublasCgemm:
        return new CgemmFunction(function);
    case e_cublasCgemm3mStridedBatched:
        return new Cgemm3mStridedBatchedFunction(function);
    default:
        throw std::runtime_error("invalid function name");
    }
}
/**
 * @brief Run the entire benchmark process according to cmd arguments.
 *
 * Parses the para_info_json string into a CublasFunction, applies the
 * step/seed options, instantiates the matching subclass and runs its
 * benchmark. Any std::exception is reported on stdout.
 *
 * @param options the cmd arguments of the application
 */
void run_benchmark(Options &options) {
    try {
        json function_config = json::parse(options.para_info_json);
        CublasFunction function = function_config.get<CublasFunction>();
        function.set_num_test(options.num_test);
        function.set_warm_up(options.warm_up);
        function.set_num_in_step(options.num_in_step);
        function.set_random_seed(options.random_seed);
        // unique_ptr ensures the function object is released even if
        // benchmark() throws; the original raw pointer leaked in that case.
        std::unique_ptr<CublasFunction> p_function(get_cublas_function_pointer(function));
        p_function->benchmark();
        std::cout << "~delete" << std::endl;
    } catch (const std::exception &e) {
        std::cout << "Error: " << e.what() << std::endl;
    }
}
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/**
* @file cublas_helper.cpp
* @brief Cpp file for some functions related to cublas
*/
#include "cublas_benchmark.h"
/**
 * @brief Check a CUDA runtime call's status and raise a descriptive error.
 *
 * Invoked through the CUDA_SAFE_CALL macro, which supplies the call text,
 * source file and line for the message.
 * @throws std::runtime_error when result is not cudaSuccess
 */
void check_cuda(cudaError_t result, char const *const func, const char *const file, int const line) {
    if (result == cudaSuccess) {
        return;
    }
    std::stringstream safe_call_ss;
    safe_call_ss << func << " failed with error"
                 << "\nfile: " << file << "\nline: " << line << "\nmsg: " << cudaGetErrorString(result);
    throw std::runtime_error(safe_call_ss.str());
}
/**
 * @brief Check a cuBLAS call's status and raise a descriptive error.
 *
 * Invoked through the CUBLAS_SAFE_CALL macro; the raw cublasStatus_t value
 * is embedded in the message.
 * @throws std::runtime_error when result is not CUBLAS_STATUS_SUCCESS
 */
void check_cublas(cublasStatus_t result, char const *const func, const char *const file, int const line) {
    if (result == CUBLAS_STATUS_SUCCESS) {
        return;
    }
    std::stringstream safe_call_ss;
    safe_call_ss << func << " failed with error"
                 << "\nfile: " << file << "\nline: " << line << "\nmsg: " << result;
    throw std::runtime_error(safe_call_ss.str());
}
/**
 * @brief Cuda context init: reset the device, select device 0 and create
 * the cublas handle used by all kernel invocations.
 *
 * NOTE: cudaDeviceReset destroys all prior allocations/state on the device,
 * so this must run before any buffers are prepared.
 * @param cublas_handle out-parameter receiving the created handle
 */
void cuda_init(cublasHandle_t *cublas_handle) {
    CUDA_SAFE_CALL(cudaDeviceReset());
    CUDA_SAFE_CALL(cudaSetDevice(0));
    // create streams/handles
    CUBLAS_SAFE_CALL(cublasCreate(cublas_handle));
}
/**
 * @brief Cuda context free: destroy the cublas handle.
 *
 * NOTE(review): the trailing cudaSetDevice(0) looks redundant here
 * (device 0 is already selected by cuda_init) — confirm intent before
 * removing.
 * @param cublas_handle the handle to destroy
 */
void cuda_free(cublasHandle_t *cublas_handle) {
    CUBLAS_SAFE_CALL(cublasDestroy(*cublas_handle));
    CUDA_SAFE_CALL(cudaSetDevice(0));
}
/**
 * @brief Wrapper of cublasSgemm (single-precision GEMM).
 * @param handle cublas handle
 * @param transa whether matrixA is transposed (non-zero => CUBLAS_OP_T)
 * @param transb whether matrixB is transposed (non-zero => CUBLAS_OP_T)
 * @param m m dim
 * @param n n dim
 * @param k k dim
 * @param a input matrixA
 * @param b input matrixB
 * @param c output matrix (note alpha=1, beta=1: results accumulate into c)
 */
void sgemm(cublasHandle_t handle, int transa, int transb, int m, int n, int k, const float *a, const float *b,
           float *c) {
    const float alpha = 1.0f;
    const float beta = 1.0f;
    const cublasOperation_t op_a = transa ? CUBLAS_OP_T : CUBLAS_OP_N;
    const cublasOperation_t op_b = transb ? CUBLAS_OP_T : CUBLAS_OP_N;
    // Leading dimensions follow the storage of the (possibly transposed) operands.
    const int lda = transa ? k : m;
    const int ldb = transb ? n : k;
    CUBLAS_SAFE_CALL(cublasSgemm(handle, op_a, op_b, m, n, k, &alpha, a, lda, b, ldb, &beta, c, m));
}
/**
 * @brief Wrapper of cublasCgemm (single-precision complex GEMM).
 * @param handle cublas handle
 * @param transa whether matrixA is transposed (non-zero => CUBLAS_OP_T)
 * @param transb whether matrixB is transposed (non-zero => CUBLAS_OP_T)
 * @param m m dim
 * @param n n dim
 * @param k k dim
 * @param a input matrixA
 * @param b input matrixB
 * @param c output matrix (alpha=1+0i, beta=0: c is overwritten)
 */
void cgemm(cublasHandle_t handle, int transa, int transb, int m, int n, int k, const cuComplex *a, const cuComplex *b,
           cuComplex *c) {
    const cuComplex alpha = make_cuComplex(1.0f, 0.0f);
    const cuComplex beta = make_cuComplex(0.0f, 0.0f);
    const cublasOperation_t op_a = transa ? CUBLAS_OP_T : CUBLAS_OP_N;
    const cublasOperation_t op_b = transb ? CUBLAS_OP_T : CUBLAS_OP_N;
    // Leading dimensions follow the storage of the (possibly transposed) operands.
    const int lda = transa ? k : m;
    const int ldb = transb ? n : k;
    CUBLAS_SAFE_CALL(cublasCgemm(handle, op_a, op_b, m, n, k, &alpha, a, lda, b, ldb, &beta, c, m));
}
/**
 * @brief Wrapper of cublasGemmEx (mixed-precision GEMM).
 * @param handle cublas handle
 * @param transa whether matrixA is transposed (non-zero => CUBLAS_OP_T)
 * @param transb whether matrixB is transposed (non-zero => CUBLAS_OP_T)
 * @param m m dim
 * @param n n dim
 * @param k k dim
 * @param a input matrixA
 * @param b input matrixB
 * @param c output matrix
 * @param type matrix type, 'float' or 'half'
 * @param use_tensor_core whether to use the tensor-core algorithm
 * @throws std::runtime_error for an unsupported datatype
 */
void gemmEx(cublasHandle_t handle, int transa, int transb, int m, int n, int k, const void *a, const void *b, void *c,
            std::string type, bool use_tensor_core) {
    float alpha = 1.0f;
    float beta = 0.0f;
    cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
    cudaDataType_t matrix_type;
    cublasGemmAlgo_t algo = use_tensor_core ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT;
    // BUG FIX: std::string::compare returns 0 on equality, so the original
    // `if (type.compare("float"))` selected CUDA_R_32F for every type EXCEPT
    // "float" and CUDA_R_16F when type WAS "float" — exactly inverted.
    // Compare with operator== and throw std::runtime_error (a thrown
    // `const char*` is not caught by the project's catch (std::exception&)).
    if (type == "float") {
        matrix_type = CUDA_R_32F;
    } else if (type == "half") {
        matrix_type = CUDA_R_16F;
    } else {
        throw std::runtime_error("invalid datatype");
    }
    CUBLAS_SAFE_CALL(cublasGemmEx(handle, (transa ? CUBLAS_OP_T : CUBLAS_OP_N), (transb ? CUBLAS_OP_T : CUBLAS_OP_N), m,
                                  n, k, &alpha, a, matrix_type, (transa ? k : m), b, matrix_type, (transb ? n : k),
                                  &beta, c, matrix_type, m, compute_type, algo));
}
/**
 * @brief cublas function of gemmStridedBatchedEx, wrapper of cublasGemmStridedBatchedEx
 * @param handle cublas handle
 * @param transa whether matrixA transpose
 * @param transb whether matrixB transpose
 * @param m m of matrix m*n,n*k
 * @param n n of matrix m*n,n*k
 * @param k k of matrix m*n,n*k
 * @param a input matrixA
 * @param b input matrixB
 * @param c output matrix
 * @param type matrix type, 'float' or 'half'
 * @param use_tensor_core whether use tensor core
 * @param batchCount the count of batch used to compute
 */
void gemmStridedBatchedEx(cublasHandle_t handle, int transa, int transb, int m, int n, int k, const void *a,
                          const void *b, void *c, std::string type, bool use_tensor_core, int batchCount) {
    float alpha = 1.0f;
    float beta = 1.0f;
    cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
    cudaDataType_t matrix_type;
    cublasGemmAlgo_t algo = use_tensor_core ? CUBLAS_GEMM_DFALT_TENSOR_OP : CUBLAS_GEMM_DFALT;
    // BUG FIX: std::string::compare() returns 0 on equality, so the original
    // `if (type.compare("float"))` selected CUDA_R_32F for every type EXCEPT
    // "float", and the invalid-type throw was unreachable. Use operator== so
    // each supported type maps to its own data type and bad input throws.
    if (type == "float") {
        matrix_type = CUDA_R_32F;
    } else if (type == "half") {
        matrix_type = CUDA_R_16F;
    } else {
        // TODO(review): throws a const char*, not caught by `catch (std::exception &)`
        // in main -- consider std::invalid_argument once <stdexcept> is available.
        throw "invalid datatype";
    }
    // Batches are packed contiguously: strides equal each matrix's element count.
    CUBLAS_SAFE_CALL(cublasGemmStridedBatchedEx(handle, (transa ? CUBLAS_OP_T : CUBLAS_OP_N),
                                                (transb ? CUBLAS_OP_T : CUBLAS_OP_N), m, n, k, &alpha, a, matrix_type,
                                                (transa ? k : m), m * k, b, matrix_type, (transb ? n : k), n * k, &beta,
                                                c, matrix_type, m, m * n, batchCount, compute_type, algo));
}
/**
 * @brief cublas function of sgemmStridedBatched, wrapper of cublasSgemmStridedBatched
 * @param handle cublas handle
 * @param transa whether matrixA transpose
 * @param transb whether matrixB transpose
 * @param m m of matrix m*n,n*k
 * @param n n of matrix m*n,n*k
 * @param k k of matrix m*n,n*k
 * @param a input matrixA
 * @param b input matrixB
 * @param c output matrix
 * @param batchCount the count of batch used to compute
 */
void sgemmStridedBatched(cublasHandle_t handle, int transa, int transb, int m, int n, int k, const float *a,
                         const float *b, float *c, int batchCount) {
    const float alpha = 1.0f;
    const float beta = 1.0f;
    const cublasOperation_t op_a = transa ? CUBLAS_OP_T : CUBLAS_OP_N;
    const cublasOperation_t op_b = transb ? CUBLAS_OP_T : CUBLAS_OP_N;
    const int lda = transa ? k : m;
    const int ldb = transb ? n : k;
    // Batches are packed back-to-back, so each stride is the matrix element count.
    CUBLAS_SAFE_CALL(cublasSgemmStridedBatched(handle, op_a, op_b, m, n, k, &alpha, a, lda, m * k, b, ldb, n * k,
                                               &beta, c, m, m * n, batchCount));
}
/**
 * @brief cublas function of cgemm3mStridedBatched, wrapper of cublasCgemm3mStridedBatched
 * @param handle cublas handle
 * @param transa whether matrixA transpose
 * @param transb whether matrixB transpose
 * @param m m of matrix m*n,n*k
 * @param n n of matrix m*n,n*k
 * @param k k of matrix m*n,n*k
 * @param a input matrixA
 * @param b input matrixB
 * @param c output matrix
 * @param batchCount the count of batch used to compute
 */
void cgemm3mStridedBatched(cublasHandle_t handle, int transa, int transb, int m, int n, int k, const cuComplex *a,
                           const cuComplex *b, cuComplex *c, int batchCount) {
    // alpha = 1 + 0i, beta = 0 + 0i: plain C = A * B with no accumulation.
    const cuComplex alpha = make_cuComplex(1.0f, 0.0f);
    const cuComplex beta = make_cuComplex(0.0f, 0.0f);
    const cublasOperation_t op_a = transa ? CUBLAS_OP_T : CUBLAS_OP_N;
    const cublasOperation_t op_b = transb ? CUBLAS_OP_T : CUBLAS_OP_N;
    const int lda = transa ? k : m;
    const int ldb = transb ? n : k;
    // Batches are packed back-to-back, so each stride is the matrix element count.
    CUBLAS_SAFE_CALL(cublasCgemm3mStridedBatched(handle, op_a, op_b, m, n, k, &alpha, a, lda, m * k, b, ldb, n * k,
                                                 &beta, c, m, m * n, batchCount));
}
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/**
* @file cublas_helper.h
* @brief Header file for some functions related to cublas
*/
#pragma once
#include <sstream>
#include <string>
#include <cuComplex.h>
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
/**
* @brief check cuda function running status and throw error str
*/
void check_cuda(cudaError_t result, char const *const func, const char *const file, int const line);
#define CUDA_SAFE_CALL(x) check_cuda((x), #x, __FILE__, __LINE__)
/**
* @brief check cublas function running status and throw error str
*/
void check_cublas(cublasStatus_t result, char const *const func, const char *const file, int const line);
#define CUBLAS_SAFE_CALL(x) check_cublas((x), #x, __FILE__, __LINE__)
/**
* @brief Cuda context init
*/
void cuda_init(cublasHandle_t *cublas_handle);
/**
* @brief Cuda context free
*/
void cuda_free(cublasHandle_t *cublas_handle);
/**
* @brief cublas function of gemm, wrapper of cublasSgemm
* @param handle cublas handle
* @param transa whether matrixA transpose
* @param transb whether matrixB transpose
* @param m m of matrix m*n,n*k
* @param n n of matrix m*n,n*k
* @param k k of matrix m*n,n*k
* @param a input matrixA
* @param b input matrixB
* @param c output matrix
*/
void sgemm(cublasHandle_t handle, int transa, int transb, int m, int n, int k, const float *a, const float *b,
float *c);
/**
* @brief cublas function of gemm, wrapper of cublasCgemm
* @param handle cublas handle
* @param transa whether matrixA transpose
* @param transb whether matrixB transpose
* @param m m of matrix m*n,n*k
* @param n n of matrix m*n,n*k
* @param k k of matrix m*n,n*k
* @param a input matrixA
* @param b input matrixB
* @param c output matrix
*/
void cgemm(cublasHandle_t handle, int transa, int transb, int m, int n, int k, const cuComplex *a, const cuComplex *b,
cuComplex *c);
/**
 * @brief cublas function of GemmEx, wrapper of cublasGemmEx
 * @param handle cublas handle
 * @param transa whether matrixA transpose
 * @param transb whether matrixB transpose
 * @param m m of matrix m*n,n*k
 * @param n n of matrix m*n,n*k
 * @param k k of matrix m*n,n*k
 * @param a input matrixA
 * @param b input matrixB
 * @param c output matrix
 * @param type matrix type, 'float' or 'half'
 * @param use_tensor_core whether use tensor core
 */
// CONSISTENCY FIX: parameter names were A/B/C here but a/b/c in the Doxygen
// comment and the definition; renamed to match (no ABI impact).
void gemmEx(cublasHandle_t handle, int transa, int transb, int m, int n, int k, const void *a, const void *b, void *c,
            std::string type, bool use_tensor_core);
/**
* @brief cublas function of gemmStridedBatchedEx, wrapper of cublasGemmStridedBatchedEx
* @param handle cublas handle
* @param transa whether matrixA transpose
* @param transb whether matrixB transpose
* @param m m of matrix m*n,n*k
* @param n n of matrix m*n,n*k
* @param k k of matrix m*n,n*k
* @param a input matrixA
* @param b input matrixB
* @param c output matrix
* @param type matrix type, 'float' or 'half'
* @param use_tensor_core whether use tensor core
 * @param batchCount the count of batch used to compute
*/
void gemmStridedBatchedEx(cublasHandle_t handle, int transa, int transb, int m, int n, int k, const void *a,
const void *b, void *c, std::string type, bool use_tensor_core, int batchCount);
/**
 * @brief cublas function of cgemm3mStridedBatched, wrapper of cublasCgemm3mStridedBatched
* @param handle cublas handle
* @param transa whether matrixA transpose
* @param transb whether matrixB transpose
* @param m m of matrix m*n,n*k
* @param n n of matrix m*n,n*k
* @param k k of matrix m*n,n*k
* @param a input matrixA
* @param b input matrixB
* @param c output matrix
* @param batchCount the count of batch used to compute
*/
void cgemm3mStridedBatched(cublasHandle_t handle, int transa, int transb, int m, int n, int k, const cuComplex *a,
const cuComplex *b, cuComplex *c, int batchCount);
/**
 * @brief cublas function of sgemmStridedBatched, wrapper of cublasSgemmStridedBatched
* @param handle cublas handle
* @param transa whether matrixA transpose
* @param transb whether matrixB transpose
* @param m m of matrix m*n,n*k
* @param n n of matrix m*n,n*k
* @param k k of matrix m*n,n*k
* @param a input matrixA
* @param b input matrixB
* @param c output matrix
* @param batchCount the count of batch used to compute
*/
void sgemmStridedBatched(cublasHandle_t handle, int transa, int transb, int m, int n, int k, const float *a,
const float *b, float *c, int batchCount);
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/**
* @file cublas_test.cpp
* @brief Cublas function benchmark will read the params from cmd, and use these params
* to benchmark the wall time of the cublas functions.
*/
#include "cublas_function_helper.h"
/**
 * @brief Main function and entry of cublas benchmark
 * @details
 * params list:
 *   num_test: test step nums
 *   warm_up: warm up step nums
 *   num_in_step: times each step will invoke the function
 *   config path: the path of 'para_info.json'
 * functions supported:
 *   cublasSgemm
 *   cublasGemmEx
 *   cublasSgemmStridedBatched
 *   cublasGemmStridedBatchedEx
 *   cublasCgemm
 *   cublasCgemm3mStridedBatched
 * @param argc argument count from the command line
 * @param argv argument vector from the command line
 * @return int 0 on success, -1 on error
 */
int main(int argc, char *argv[]) {
    try {
        // parse arguments from cmd
        Options options(argc, argv);
        // benchmark each function defined in 'para_info.json'
        run_benchmark(options);
    } catch (const std::exception &e) {
        // Catch by const reference (best practice); report and fail with the
        // same process exit status as the original exit(-1).
        std::cout << "Error: " << e.what() << std::endl;
        return -1;
    }
    return 0;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment