Unverified Commit b808135c authored by Yifan Xiong, committed by GitHub

Benchmarks - Support tensor core precisions in cublaslt gemm (#492)

Support FP64/TF32/FP16/BF16 in cublaslt (batch) GEMM.
parent 139d4df5
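Note: the commit title mentions TF32, but no `tf32` entry is added to the input-type list; TF32 arrives through the compute type instead, with `fp32` inputs run under CUBLAS_COMPUTE_32F_FAST_TF32 (see the cublaslt_utils.cc hunk at the end). A minimal sketch of that pairing using only the public cuBLASLt API; the helper name is illustrative, not from this PR, and status checking is omitted:

    #include <cublasLt.h>

    // Sketch: descriptor for an FP32-in/FP32-out GEMM that is allowed to use
    // TF32 tensor cores internally; the scale type stays CUDA_R_32F.
    cublasLtMatmulDesc_t make_tf32_matmul_desc() {
        cublasLtMatmulDesc_t op_desc = nullptr;
        cublasLtMatmulDescCreate(&op_desc, CUBLAS_COMPUTE_32F_FAST_TF32, CUDA_R_32F);
        return op_desc;
    }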
@@ -21,8 +21,8 @@ def __init__(self, name, parameters=''):
         """
         super().__init__(name, parameters)
 
-        self._bin_name = 'cublaslt_fp8_gemm'
-        self._in_types = ['fp16', 'fp8e4m3', 'fp8e5m2']
+        self._bin_name = 'cublaslt_gemm'
+        self._in_types = ['fp64', 'fp32', 'fp16', 'bf16', 'fp8e4m3', 'fp8e5m2']
 
     def add_parser_arguments(self):
         """Add the specified arguments."""
...
@@ -2,7 +2,7 @@
 # Licensed under the MIT License.
 
 cmake_minimum_required(VERSION 3.18)
-project(cublaslt_fp8_gemm LANGUAGES CXX)
+project(cublaslt_gemm LANGUAGES CXX)
 
 find_package(CUDAToolkit QUIET)
@@ -15,8 +15,8 @@ if(CUDAToolkit_FOUND AND NOT CUDAToolkit_VERSION VERSION_LESS 11.8)
     set_target_properties(cublaslt_utils PROPERTIES LINK_FLAGS_RELEASE -s)
     install(TARGETS cublaslt_utils LIBRARY DESTINATION lib)
 
-    add_executable(cublaslt_fp8_gemm cublaslt_fp8_gemm.cu)
-    target_link_libraries(cublaslt_fp8_gemm cublaslt_utils)
-    set_target_properties(cublaslt_fp8_gemm PROPERTIES CUDA_ARCHITECTURES "80;86;90")
-    install(TARGETS cublaslt_fp8_gemm RUNTIME DESTINATION bin)
+    add_executable(cublaslt_gemm cublaslt_gemm.cu)
+    target_link_libraries(cublaslt_gemm cublaslt_utils)
+    set_target_properties(cublaslt_gemm PROPERTIES CUDA_ARCHITECTURES "80;86;90")
+    install(TARGETS cublaslt_gemm RUNTIME DESTINATION bin)
 endif()
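The CUDAToolkit >= 11.8 gate above is presumably there because the FP8 types the benchmark uses (__nv_fp8_e4m3, __nv_fp8_e5m2) first ship with CUDA 11.8. A compile-time sketch of the same guard, ours rather than the PR's:

    #include <cuda.h>

    // Assumption: cuda_fp8.h and the __nv_fp8_* types exist only on CUDA 11.8+.
    #if CUDA_VERSION >= 11080
    #include <cuda_fp8.h>
    using fp8e4m3 = __nv_fp8_e4m3;
    using fp8e5m2 = __nv_fp8_e5m2;
    #endif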
@@ -10,7 +10,10 @@
 #include "cublaslt_utils.h"
 
-using fp16 = half; // nv_bfloat16
+using fp64 = double;
+using fp32 = float;
+using fp16 = half;
+using bf16 = nv_bfloat16;
 using fp8e4m3 = __nv_fp8_e4m3;
 using fp8e5m2 = __nv_fp8_e5m2;
@@ -61,7 +64,7 @@ void process_args(int argc, char **argv, Args *args) {
     }
 }
 
-template <typename T> __global__ void init_matrix(T *matrix, const fp16 val, const size_t N) {
+template <typename T> __global__ void init_matrix(T *matrix, const fp32 val, const size_t N) {
     size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
     for (size_t i = tid; i < N; i += gridDim.x * blockDim.x) {
         matrix[i] = T(val);
@@ -69,8 +72,14 @@ template <typename T> __global__ void init_matrix(T *matrix, const fp16 val, con
 }
 
 template <typename T> cudaDataType_t get_datatype() {
+    if (std::is_same<T, fp64>::value)
+        return CUDA_R_64F;
+    if (std::is_same<T, fp32>::value)
+        return CUDA_R_32F;
     if (std::is_same<T, fp16>::value)
         return CUDA_R_16F;
+    if (std::is_same<T, bf16>::value)
+        return CUDA_R_16BF;
     if (std::is_same<T, fp8e4m3>::value)
         return CUDA_R_8F_E4M3;
     if (std::is_same<T, fp8e5m2>::value)
@@ -88,8 +97,8 @@ float timing_matmul_tn(int m, int n, int k, int batch, int warmup, int iter) {
     cudaMalloc(&matrix_b, k * n * std::max(batch, 1) * sizeof(Tb));
     cudaMalloc(&matrix_out, m * n * std::max(batch, 1) * sizeof(Tout));
 
-    init_matrix<Ta><<<216, 1024>>>(matrix_a, static_cast<fp16>(1.f), m * k * std::max(batch, 1));
-    init_matrix<Tb><<<216, 1024>>>(matrix_b, static_cast<fp16>(2.f), k * n * std::max(batch, 1));
+    init_matrix<Ta><<<216, 1024>>>(matrix_a, 1.f, m * k * std::max(batch, 1));
+    init_matrix<Tb><<<216, 1024>>>(matrix_b, 2.f, k * n * std::max(batch, 1));
 
     // init gemm
     int lda = k, ldb = k, ldd = m;
@@ -129,7 +138,7 @@ float timing_matmul_tn(int m, int n, int k, int batch, int warmup, int iter) {
     return (time * 1e3 / iter);
 }
 
-template <typename Ta, typename Tb = Ta, typename Tout = fp16> void run(Args *args) {
+template <typename Ta, typename Tb = Ta, typename Tout = Ta> void run(Args *args) {
     float time_us = timing_matmul_tn<Ta, Tb, Tout>(args->m, args->n, args->k, args->batch, args->warmup, args->iter);
 
     // m n k batch time_us tflops
     printf("%d\t%d\t%d\t%d\t%f\t%f\n", args->m, args->n, args->k, args->batch, time_us,
@@ -140,12 +149,18 @@ int main(int argc, char **argv) {
     Args args;
     process_args(argc, argv, &args);
 
-    if (args.in_type == "fp16")
+    if (args.in_type == "fp64")
+        run<fp64>(&args);
+    else if (args.in_type == "fp32")
+        run<fp32>(&args);
+    else if (args.in_type == "fp16")
         run<fp16>(&args);
+    else if (args.in_type == "bf16")
+        run<bf16>(&args);
     else if (args.in_type == "fp8e4m3")
-        run<fp8e4m3>(&args);
+        run<fp8e4m3, fp8e4m3, fp16>(&args);
     else if (args.in_type == "fp8e5m2")
-        run<fp8e5m2, fp8e4m3>(&args);
+        run<fp8e5m2, fp8e4m3, fp16>(&args);
     else
         throw std::invalid_argument("Unknown type " + args.in_type);
...
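The printf above reports the per-call latency in microseconds and a TFLOPS figure; the tflops expression itself is truncated in this view, but it is presumably the standard 2*m*n*k FLOP count per GEMM, scaled by the batch size. A minimal sketch of that arithmetic (the helper name is ours, not from the benchmark):

    #include <algorithm>

    // Sketch: TFLOPS of a (possibly batched) m x n x k GEMM, given the average
    // latency of one call in microseconds. Each output element costs one
    // multiply and one add, hence 2*m*n*k FLOPs per GEMM.
    double gemm_tflops(int m, int n, int k, int batch, double time_us) {
        double flops = 2.0 * m * n * k * std::max(batch, 1);
        return flops / (time_us * 1e-6) / 1e12;  // FLOPs / seconds / 1e12
    }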
@@ -22,7 +22,7 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l
 ) {
     cublasLtMatrixLayout_t a_desc = nullptr, b_desc = nullptr, c_desc = nullptr, d_desc = nullptr;
     // force c_type
-    cudaDataType_t c_type = CUDA_R_16F;
+    cudaDataType_t c_type = d_type;
     // Create matrix descriptors.
     checkCublasStatus(
         cublasLtMatrixLayoutCreate(&a_desc, a_type, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda));
@@ -57,10 +57,11 @@ void cublasLtGemm::Setup(int m, int n, int k, int batch, int lda, int ldb, int l
     d_desc_.reset(d_desc);
 
     // default to tf32 except for e5m2 inputs where the config is not supported
-    cublasComputeType_t gemm_compute_type =
-        (a_type == CUDA_R_8F_E5M2 || b_type == CUDA_R_8F_E5M2 || a_type == CUDA_R_8F_E4M3 || b_type == CUDA_R_8F_E4M3)
-            ? CUBLAS_COMPUTE_32F
-            : CUBLAS_COMPUTE_32F_FAST_TF32;
+    cublasComputeType_t gemm_compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
+    if (a_type == CUDA_R_8F_E5M2 || b_type == CUDA_R_8F_E5M2 || a_type == CUDA_R_8F_E4M3 || b_type == CUDA_R_8F_E4M3)
+        gemm_compute_type = CUBLAS_COMPUTE_32F;
+    if (a_type == CUDA_R_64F || b_type == CUDA_R_64F)
+        gemm_compute_type = CUBLAS_COMPUTE_64F;
 
     cublasLtMatmulDesc_t op_desc = nullptr;
     checkCublasStatus(cublasLtMatmulDescCreate(&op_desc, gemm_compute_type, CUDA_R_32F));
...
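Read as a rule, the new selection is: default to TF32 fast math, fall back to plain FP32 accumulation for FP8 inputs (where the TF32 config is not supported), and switch to FP64 compute for FP64 inputs. A standalone sketch of the same logic, assuming only the cuBLASLt/CUDA enums already used in this diff:

    #include <cublasLt.h>

    // Sketch: mirror of the compute-type selection in cublasLtGemm::Setup.
    cublasComputeType_t select_compute_type(cudaDataType_t a_type, cudaDataType_t b_type) {
        cublasComputeType_t t = CUBLAS_COMPUTE_32F_FAST_TF32;  // default: TF32 fast math
        if (a_type == CUDA_R_8F_E5M2 || b_type == CUDA_R_8F_E5M2 ||
            a_type == CUDA_R_8F_E4M3 || b_type == CUDA_R_8F_E4M3)
            t = CUBLAS_COMPUTE_32F;  // FP8 paths require plain FP32 compute
        if (a_type == CUDA_R_64F || b_type == CUDA_R_64F)
            t = CUBLAS_COMPUTE_64F;  // FP64 inputs need FP64 compute
        return t;
    }

One detail visible in the hunk: the matmul descriptor is still created with scale type CUDA_R_32F on every path, including the new FP64 one.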