"git@developer.sourcefind.cn:gaoqiong/pybind11.git" did not exist on "c0fbb02c9fc1fe969f2740253469499385e6dc20"
Unverified Commit 606ff191 authored by Ziyue Yang's avatar Ziyue Yang Committed by GitHub
Browse files

Benchmarks: Microbenchmark - Support different hipblasLt data types in dist_inference (#590)

**Description**
Support different data types in different hipblasLt versions for
dist_inference
parent 2c2096ed
...@@ -31,6 +31,9 @@ else() ...@@ -31,6 +31,9 @@ else()
# link hip device lib # link hip device lib
add_executable(dist_inference dist_inference.cpp) add_executable(dist_inference dist_inference.cpp)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2 -DROCM_USE_FLOAT16=1") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2 -DROCM_USE_FLOAT16=1")
if(DEFINED ENV{USE_HIPBLASLT_DATATYPE})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_HIPBLASLT_DATATYPE=1")
endif()
target_link_libraries(dist_inference MPI::MPI_CXX rccl hipblaslt hip::device) target_link_libraries(dist_inference MPI::MPI_CXX rccl hipblaslt hip::device)
else() else()
message(FATAL_ERROR "No CUDA or ROCm environment found.") message(FATAL_ERROR "No CUDA or ROCm environment found.")
......
...@@ -45,6 +45,13 @@ ...@@ -45,6 +45,13 @@
#include <hipblaslt/hipblaslt.h> #include <hipblaslt/hipblaslt.h>
#include <rccl/rccl.h> #include <rccl/rccl.h>
using cublasLtHalf = hipblasLtHalf; using cublasLtHalf = hipblasLtHalf;
#if defined(USE_HIPBLASLT_DATATYPE)
#define DIST_INF_HIP_DATATYPE_R_16F HIPBLASLT_R_16F
#define DIST_INF_HIP_DATATYPE_R_32F HIPBLASLT_R_32F
#else
#define DIST_INF_HIP_DATATYPE_R_16F HIPBLAS_R_16F
#define DIST_INF_HIP_DATATYPE_R_32F HIPBLAS_R_32F
#endif
#else #else
#include <cublasLt.h> #include <cublasLt.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
...@@ -229,16 +236,16 @@ void TestModel(int64_t m, int64_t n, int64_t k, float alpha, float beta, int32_t ...@@ -229,16 +236,16 @@ void TestModel(int64_t m, int64_t n, int64_t k, float alpha, float beta, int32_t
CHECK_CUBLASLT_ERROR(hipblasLtCreate(&handle)); CHECK_CUBLASLT_ERROR(hipblasLtCreate(&handle));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matA, HIPBLAS_R_16F, k, n, k)); CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matA, DIST_INF_HIP_DATATYPE_R_16F, k, n, k));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matB, HIPBLAS_R_16F, m, k, m)); CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matB, DIST_INF_HIP_DATATYPE_R_16F, m, k, m));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matC, HIPBLAS_R_16F, m, n, m)); CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matC, DIST_INF_HIP_DATATYPE_R_16F, m, n, m));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matD, HIPBLAS_R_16F, m, n, m)); CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matD, DIST_INF_HIP_DATATYPE_R_16F, m, n, m));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matE, HIPBLAS_R_16F, k, m, k)); CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matE, DIST_INF_HIP_DATATYPE_R_16F, k, m, k));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matF, HIPBLAS_R_16F, k, n, k)); CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matF, DIST_INF_HIP_DATATYPE_R_16F, k, n, k));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matG, HIPBLAS_R_16F, k, n, k)); CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matG, DIST_INF_HIP_DATATYPE_R_16F, k, n, k));
CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul1, HIPBLASLT_COMPUTE_F32, HIPBLAS_R_32F)); CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul1, HIPBLASLT_COMPUTE_F32, DIST_INF_HIP_DATATYPE_R_32F));
CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul2, HIPBLASLT_COMPUTE_F32, HIPBLAS_R_32F)); CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul2, HIPBLASLT_COMPUTE_F32, DIST_INF_HIP_DATATYPE_R_32F));
hipblasOperation_t trans = HIPBLAS_OP_N; hipblasOperation_t trans = HIPBLAS_OP_N;
CHECK_CUBLASLT_ERROR( CHECK_CUBLASLT_ERROR(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment