Unverified Commit bfb625b7 authored by Ziyue Yang, committed by GitHub

Benchmarks: Microbenchmark - Adapt to hipblasLt data type changes (#603)

**Description**
Adapt to hipblasLt data type changes.
parent b85f6851
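
The patch keeps a single set of benchmark-local macro names and lets the build decide which enum spelling they expand to, driven by the `USE_HIPBLASLT_DATATYPE`, `USE_HIP_DATATYPE`, and `USE_HIPBLAS_COMPUTETYPE` environment variables that the CMake change below turns into compile definitions. As a small hypothetical probe (not part of the patch; the compile command is illustrative only), the same selection logic can be exercised outside the benchmark to see which enum family a given flag combination picks:

```cpp
// Hypothetical probe, e.g.: hipcc -DUSE_HIP_DATATYPE=1 -DUSE_HIPBLAS_COMPUTETYPE=1 which_types.cpp
#include <cstdio>

int main() {
#if defined(USE_HIPBLASLT_DATATYPE)
    std::puts("data types: hipblasLt-specific (HIPBLASLT_R_16F / HIPBLASLT_R_32F)");
#elif defined(USE_HIP_DATATYPE)
    std::puts("data types: generic hipDataType (HIP_R_16F / HIP_R_32F)");
#else
    std::puts("data types: hipBLAS (HIPBLAS_R_16F / HIPBLAS_R_32F)");
#endif
#if defined(USE_HIPBLAS_COMPUTETYPE)
    std::puts("compute type: HIPBLAS_COMPUTE_32F");
#else
    std::puts("compute type: HIPBLASLT_COMPUTE_F32");
#endif
    return 0;
}
```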
@@ -33,6 +33,11 @@ else()
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2 -DROCM_USE_FLOAT16=1")
   if(DEFINED ENV{USE_HIPBLASLT_DATATYPE})
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_HIPBLASLT_DATATYPE=1")
+  elseif(DEFINED ENV{USE_HIP_DATATYPE})
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_HIP_DATATYPE=1")
+  endif()
+  if(DEFINED ENV{USE_HIPBLAS_COMPUTETYPE})
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_HIPBLAS_COMPUTETYPE=1")
   endif()
   target_link_libraries(dist_inference MPI::MPI_CXX rccl hipblaslt hip::device)
 else()
...
@@ -48,10 +48,18 @@ using cublasLtHalf = hipblasLtHalf;
 #if defined(USE_HIPBLASLT_DATATYPE)
 #define DIST_INF_HIP_DATATYPE_R_16F HIPBLASLT_R_16F
 #define DIST_INF_HIP_DATATYPE_R_32F HIPBLASLT_R_32F
+#elif defined(USE_HIP_DATATYPE)
+#define DIST_INF_HIP_DATATYPE_R_16F HIP_R_16F
+#define DIST_INF_HIP_DATATYPE_R_32F HIP_R_32F
 #else
 #define DIST_INF_HIP_DATATYPE_R_16F HIPBLAS_R_16F
 #define DIST_INF_HIP_DATATYPE_R_32F HIPBLAS_R_32F
 #endif
+#if defined(USE_HIPBLAS_COMPUTETYPE)
+#define DIST_INF_HIP_COMPUTETYPE_F32 HIPBLAS_COMPUTE_32F
+#else
+#define DIST_INF_HIP_COMPUTETYPE_F32 HIPBLASLT_COMPUTE_F32
+#endif
 #else
 #include <cublasLt.h>
 #include <cuda_fp16.h>
@@ -244,8 +252,10 @@ void TestModel(int64_t m, int64_t n, int64_t k, float alpha, float beta, int32_t
     CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matF, DIST_INF_HIP_DATATYPE_R_16F, k, n, k));
     CHECK_CUBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matG, DIST_INF_HIP_DATATYPE_R_16F, k, n, k));
-    CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul1, HIPBLASLT_COMPUTE_F32, DIST_INF_HIP_DATATYPE_R_32F));
-    CHECK_CUBLASLT_ERROR(hipblasLtMatmulDescCreate(&matmul2, HIPBLASLT_COMPUTE_F32, DIST_INF_HIP_DATATYPE_R_32F));
+    CHECK_CUBLASLT_ERROR(
+        hipblasLtMatmulDescCreate(&matmul1, DIST_INF_HIP_COMPUTETYPE_F32, DIST_INF_HIP_DATATYPE_R_32F));
+    CHECK_CUBLASLT_ERROR(
+        hipblasLtMatmulDescCreate(&matmul2, DIST_INF_HIP_COMPUTETYPE_F32, DIST_INF_HIP_DATATYPE_R_32F));
     hipblasOperation_t trans = HIPBLAS_OP_N;
     CHECK_CUBLASLT_ERROR(
...
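
For illustration only, a self-contained sketch of how the selected names end up being consumed by the hipBLASLt setup calls used in the benchmark. The header path, the `check()` helper, the handle management, and the 1024x1024 layout size are assumptions of mine rather than part of the patch, and the program presumes a ROCm install with hipBLASLt present:

```cpp
// Hypothetical standalone example; assumes ROCm with hipBLASLt, built e.g. with:
//   hipcc -DUSE_HIP_DATATYPE=1 -DUSE_HIPBLAS_COMPUTETYPE=1 demo.cpp -lhipblaslt
#include <cstdio>
#include <cstdlib>
#include <hipblaslt/hipblaslt.h> // header path may vary across ROCm releases

// Same selection logic as the patch: one macro name, three possible enum spellings.
#if defined(USE_HIPBLASLT_DATATYPE)
#define DIST_INF_HIP_DATATYPE_R_16F HIPBLASLT_R_16F
#define DIST_INF_HIP_DATATYPE_R_32F HIPBLASLT_R_32F
#elif defined(USE_HIP_DATATYPE)
#define DIST_INF_HIP_DATATYPE_R_16F HIP_R_16F
#define DIST_INF_HIP_DATATYPE_R_32F HIP_R_32F
#else
#define DIST_INF_HIP_DATATYPE_R_16F HIPBLAS_R_16F
#define DIST_INF_HIP_DATATYPE_R_32F HIPBLAS_R_32F
#endif
#if defined(USE_HIPBLAS_COMPUTETYPE)
#define DIST_INF_HIP_COMPUTETYPE_F32 HIPBLAS_COMPUTE_32F
#else
#define DIST_INF_HIP_COMPUTETYPE_F32 HIPBLASLT_COMPUTE_F32
#endif

// Minimal status check, standing in for the benchmark's CHECK_CUBLASLT_ERROR macro.
static void check(hipblasStatus_t status, const char* what) {
    if (status != HIPBLAS_STATUS_SUCCESS) {
        std::fprintf(stderr, "%s failed with status %d\n", what, static_cast<int>(status));
        std::exit(EXIT_FAILURE);
    }
}

int main() {
    hipblasLtHandle_t handle;
    check(hipblasLtCreate(&handle), "hipblasLtCreate");

    // FP16 matrix layout: the second argument is whichever data-type enum the build selected.
    hipblasLtMatrixLayout_t layout;
    check(hipblasLtMatrixLayoutCreate(&layout, DIST_INF_HIP_DATATYPE_R_16F, 1024, 1024, 1024),
          "hipblasLtMatrixLayoutCreate");

    // FP32 accumulation with an FP32 scale type, mirroring the descriptor creation in TestModel.
    hipblasLtMatmulDesc_t matmul;
    check(hipblasLtMatmulDescCreate(&matmul, DIST_INF_HIP_COMPUTETYPE_F32, DIST_INF_HIP_DATATYPE_R_32F),
          "hipblasLtMatmulDescCreate");

    hipblasLtMatmulDescDestroy(matmul);
    hipblasLtMatrixLayoutDestroy(layout);
    hipblasLtDestroy(handle);
    std::puts("hipblasLt descriptor setup succeeded");
    return 0;
}
```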