Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
superbenchmark
Commits
bfb625b7
Unverified
Commit
bfb625b7
authored
Dec 22, 2023
by
Ziyue Yang
Committed by
GitHub
Dec 22, 2023
Browse files
Benchmarks: Microbenchmark - Adapt to hipblasLt data type changes (#603)
**Description** Adapt to hipblasLt data type changes.
parent
b85f6851
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
2 deletions
+17
-2
superbench/benchmarks/micro_benchmarks/dist_inference_cpp/CMakeLists.txt
...hmarks/micro_benchmarks/dist_inference_cpp/CMakeLists.txt
+5
-0
superbench/benchmarks/micro_benchmarks/dist_inference_cpp/dist_inference.cu
...rks/micro_benchmarks/dist_inference_cpp/dist_inference.cu
+12
-2
No files found.
superbench/benchmarks/micro_benchmarks/dist_inference_cpp/CMakeLists.txt
View file @
bfb625b7
...
@@ -33,6 +33,11 @@ else()
...
@@ -33,6 +33,11 @@ else()
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-O2 -DROCM_USE_FLOAT16=1"
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-O2 -DROCM_USE_FLOAT16=1"
)
if
(
DEFINED ENV{USE_HIPBLASLT_DATATYPE}
)
if
(
DEFINED ENV{USE_HIPBLASLT_DATATYPE}
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-DUSE_HIPBLASLT_DATATYPE=1"
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-DUSE_HIPBLASLT_DATATYPE=1"
)
elseif
(
DEFINED ENV{USE_HIP_DATATYPE}
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-DUSE_HIP_DATATYPE=1"
)
endif
()
if
(
DEFINED ENV{USE_HIPBLAS_COMPUTETYPE}
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-DUSE_HIPBLAS_COMPUTETYPE=1"
)
endif
()
endif
()
target_link_libraries
(
dist_inference MPI::MPI_CXX rccl hipblaslt hip::device
)
target_link_libraries
(
dist_inference MPI::MPI_CXX rccl hipblaslt hip::device
)
else
()
else
()
...
...
superbench/benchmarks/micro_benchmarks/dist_inference_cpp/dist_inference.cu
View file @
bfb625b7
...
@@ -48,10 +48,18 @@ using cublasLtHalf = hipblasLtHalf;
...
@@ -48,10 +48,18 @@ using cublasLtHalf = hipblasLtHalf;
#if defined(USE_HIPBLASLT_DATATYPE)
#if defined(USE_HIPBLASLT_DATATYPE)
#define DIST_INF_HIP_DATATYPE_R_16F HIPBLASLT_R_16F
#define DIST_INF_HIP_DATATYPE_R_16F HIPBLASLT_R_16F
#define DIST_INF_HIP_DATATYPE_R_32F HIPBLASLT_R_32F
#define DIST_INF_HIP_DATATYPE_R_32F HIPBLASLT_R_32F
#elif defined(USE_HIP_DATATYPE)
#define DIST_INF_HIP_DATATYPE_R_16F HIP_R_16F
#define DIST_INF_HIP_DATATYPE_R_32F HIP_R_32F
#else
#else
#define DIST_INF_HIP_DATATYPE_R_16F HIPBLAS_R_16F
#define DIST_INF_HIP_DATATYPE_R_16F HIPBLAS_R_16F
#define DIST_INF_HIP_DATATYPE_R_32F HIPBLAS_R_32F
#define DIST_INF_HIP_DATATYPE_R_32F HIPBLAS_R_32F
#endif
#endif
#if defined(USE_HIPBLAS_COMPUTETYPE)
#define DIST_INF_HIP_COMPUTETYPE_F32 HIPBLAS_COMPUTE_32F
#else
#define DIST_INF_HIP_COMPUTETYPE_F32 HIPBLASLT_COMPUTE_F32
#endif
#else
#else
#include <cublasLt.h>
#include <cublasLt.h>
#include <cuda_fp16.h>
#include <cuda_fp16.h>
...
@@ -244,8 +252,10 @@ void TestModel(int64_t m, int64_t n, int64_t k, float alpha, float beta, int32_t
...
@@ -244,8 +252,10 @@ void TestModel(int64_t m, int64_t n, int64_t k, float alpha, float beta, int32_t
CHECK_CUBLASLT_ERROR
(
hipblasLtMatrixLayoutCreate
(
&
matF
,
DIST_INF_HIP_DATATYPE_R_16F
,
k
,
n
,
k
));
CHECK_CUBLASLT_ERROR
(
hipblasLtMatrixLayoutCreate
(
&
matF
,
DIST_INF_HIP_DATATYPE_R_16F
,
k
,
n
,
k
));
CHECK_CUBLASLT_ERROR
(
hipblasLtMatrixLayoutCreate
(
&
matG
,
DIST_INF_HIP_DATATYPE_R_16F
,
k
,
n
,
k
));
CHECK_CUBLASLT_ERROR
(
hipblasLtMatrixLayoutCreate
(
&
matG
,
DIST_INF_HIP_DATATYPE_R_16F
,
k
,
n
,
k
));
CHECK_CUBLASLT_ERROR
(
hipblasLtMatmulDescCreate
(
&
matmul1
,
HIPBLASLT_COMPUTE_F32
,
DIST_INF_HIP_DATATYPE_R_32F
));
CHECK_CUBLASLT_ERROR
(
CHECK_CUBLASLT_ERROR
(
hipblasLtMatmulDescCreate
(
&
matmul2
,
HIPBLASLT_COMPUTE_F32
,
DIST_INF_HIP_DATATYPE_R_32F
));
hipblasLtMatmulDescCreate
(
&
matmul1
,
DIST_INF_HIP_COMPUTETYPE_F32
,
DIST_INF_HIP_DATATYPE_R_32F
));
CHECK_CUBLASLT_ERROR
(
hipblasLtMatmulDescCreate
(
&
matmul2
,
DIST_INF_HIP_COMPUTETYPE_F32
,
DIST_INF_HIP_DATATYPE_R_32F
));
hipblasOperation_t
trans
=
HIPBLAS_OP_N
;
hipblasOperation_t
trans
=
HIPBLAS_OP_N
;
CHECK_CUBLASLT_ERROR
(
CHECK_CUBLASLT_ERROR
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment