Unverified Commit 606ff191 authored by Ziyue Yang's avatar Ziyue Yang Committed by GitHub
Browse files

Benchmarks: Microbenchmark - Support different hipblasLt data types in dist_inference (#590)

**Description**
Support different data types in different hipblasLt versions for
dist_inference
parent 2c2096ed
......@@ -31,6 +31,9 @@ else()
# link hip device lib
add_executable(dist_inference dist_inference.cpp)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2 -DROCM_USE_FLOAT16=1")
if(DEFINED ENV{USE_HIPBLASLT_DATATYPE})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_HIPBLASLT_DATATYPE=1")
endif()
target_link_libraries(dist_inference MPI::MPI_CXX rccl hipblaslt hip::device)
else()
message(FATAL_ERROR "No CUDA or ROCm environment found.")
......
......@@ -45,6 +45,13 @@
#include <hipblaslt/hipblaslt.h>
#include <rccl/rccl.h>
using cublasLtHalf = hipblasLtHalf;
#if defined(USE_HIPBLASLT_DATATYPE)
#define DIST_INF_HIP_DATATYPE_R_16F HIPBLASLT_R_16F
#define DIST_INF_HIP_DATATYPE_R_32F HIPBLASLT_R_32F
#else
#define DIST_INF_HIP_DATATYPE_R_16F HIPBLAS_R_16F
#define DIST_INF_HIP_DATATYPE_R_32F HIPBLAS_R_32F
#endif
#else
#include <cublasLt.h>
#include <cuda_fp16.h>
......@@ -229,16 +236,16 @@ void TestModel(int64_t m, int64_t n, int64_t k, float alpha, float beta, int32_t
CHECK_CUBLASLT_ERROR(hipblasLtCreate(&handle));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matA, HIPBLAS_R_16F, k, n, k));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matB, HIPBLAS_R_16F, m, k, m));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matC, HIPBLAS_R_16F, m, n, m));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matD, HIPBLAS_R_16F, m, n, m));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matE, HIPBLAS_R_16F, k, m, k));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matF, HIPBLAS_R_16F, k, n, k));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matG, HIPBLAS_R_16F, k, n, k));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matA, DIST_INF_HIP_DATATYPE_R_16F, k, n, k));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matB, DIST_INF_HIP_DATATYPE_R_16F, m, k, m));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matC, DIST_INF_HIP_DATATYPE_R_16F, m, n, m));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matD, DIST_INF_HIP_DATATYPE_R_16F, m, n, m));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matE, DIST_INF_HIP_DATATYPE_R_16F, k, m, k));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matF, DIST_INF_HIP_DATATYPE_R_16F, k, n, k));
CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matG, DIST_INF_HIP_DATATYPE_R_16F, k, n, k));
CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul1, HIPBLASLT_COMPUTE_F32, HIPBLAS_R_32F));
CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul2, HIPBLASLT_COMPUTE_F32, HIPBLAS_R_32F));
CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul1, HIPBLASLT_COMPUTE_F32, DIST_INF_HIP_DATATYPE_R_32F));
CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul2, HIPBLASLT_COMPUTE_F32, DIST_INF_HIP_DATATYPE_R_32F));
hipblasOperation_t trans = HIPBLAS_OP_N;
CHECK_CUBLASLT_ERROR(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment