"vscode:/vscode.git/clone" did not exist on "ba6f79a75e65610871fd5139311817642292085c"
Unverified commit cd167e49, authored by Chao Liu, committed by GitHub

Compile for gfx908 and gfx90a (#130)

* adding compilation for multiple targets

* fix build

* clean

* update Jenkinsfile

* update readme

* update Jenkins

* use ck::half_t instead of ushort for bf16

* rename enum classes

* clean

* rename

* clean
parent ecf337ba
@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;

-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

 // Compilation parameters for a[k, m] * b[k, n] = c[m, n]
 using device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances = std::tuple<
...
@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;

-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

 // Compilation parameters for a[k, m] * b[n, k] = c[m, n]
 using device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances = std::tuple<
...
@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;

-static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization_t::MNPadding;
+static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;

 // Compilation parameters for a[m, k] * b[k, n] = c[m, n]
 using device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances = std::tuple<
...
@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;

-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

 // Compilation parameters for a[m, k] * b[n, k] = c[m, n]
 using device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances = std::tuple<
...
@@ -23,7 +23,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using ReduceSum = ck::tensor_operation::element_wise::ReduceSum;
 using ReduceSquareSum = ck::tensor_operation::element_wise::ReduceSquareSum;

-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

 // c[m, n] = a[k, m] * b[k, n]
 // d0[m] = reduce0(c[m, n])
...
@@ -23,7 +23,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using ReduceSum = ck::tensor_operation::element_wise::ReduceSum;
 using ReduceSquareSum = ck::tensor_operation::element_wise::ReduceSquareSum;

-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

 // c[m, n] = a[k, m] * b[n, k]
 // d0[m] = reduce0(c[m, n])
...
@@ -23,7 +23,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using ReduceSum = ck::tensor_operation::element_wise::ReduceSum;
 using ReduceSquareSum = ck::tensor_operation::element_wise::ReduceSquareSum;

-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

 // c[m, n] = a[m, k] * b[n, k]
 // d0[m] = reduce0(c[m, n])
...
@@ -23,7 +23,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using ReduceSum = ck::tensor_operation::element_wise::ReduceSum;
 using ReduceSquareSum = ck::tensor_operation::element_wise::ReduceSquareSum;

-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

 // c[m, n] = a[m, k] * b[n, k]
 // d0[m] = reduce0(c[m, n])
...
@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;

-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

 // Compilation parameters for a[k, m] * b[k, n] = c[m, n]
 using device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances = std::tuple<
...
@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;

-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

 // Compilation parameters for a[k, m] * b[n, k] = c[m, n]
 using device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances = std::tuple<
...
@@ -20,7 +20,7 @@ using S = ck::Sequence<Is...>;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;

-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

 // Compilation parameters for a[m, k] * b[k, n] = c[m, n]
 using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances = std::tuple<
...
@@ -20,8 +20,8 @@ using S = ck::Sequence<Is...>;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;

-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default;
-static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization_t::MNPadding;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;

 // Compilation parameters for a[m, k] * b[n, k] = c[m, n]
 using device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances = std::tuple<
...
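The change repeated across these instance files is purely a rename: the GemmSpecialization_t enum class loses its _t suffix, so only the type name at each call site changes. A minimal sketch of the pattern (simplified; the repository's enum has more members than the two shown here):

```cpp
namespace ck { namespace tensor_operation { namespace device {

// Renamed from GemmSpecialization_t; members beyond these two are omitted in this sketch.
enum struct GemmSpecialization
{
    Default,
    MNPadding,
};

}}} // namespace ck::tensor_operation::device

// Instance files then select the specialization as a compile-time constant:
static constexpr auto GemmDefault   = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
```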
## Docker script
```bash
docker run \
-it \
--rm \
--privileged \
--group-add sudo \
-w /root/workspace \
-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
rocm/tensorflow:rocm4.3.1-tf2.6-dev \
/bin/bash
```
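Here `${PATH_TO_LOCAL_WORKSPACE}` stands for the directory holding your local checkout; a concrete invocation might look like the following (the path is only an example):

```bash
# Example only: point PATH_TO_LOCAL_WORKSPACE at your own composable_kernel checkout.
export PATH_TO_LOCAL_WORKSPACE=$HOME/workspace/composable_kernel
docker run -it --rm --privileged --group-add sudo \
    -w /root/workspace \
    -v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
    rocm/tensorflow:rocm4.3.1-tf2.6-dev \
    /bin/bash
```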
## Build `ckProfiler`
```bash
mkdir build && cd build
```
```bash
# Need to specify the target ID; the example below is gfx908
cmake \
-D BUILD_DEV=OFF \
-D CMAKE_BUILD_TYPE=Release \
-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 " \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_PREFIX_PATH=/opt/rocm \
..
```
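Since this commit adds gfx90a alongside gfx908, the analogous configuration for an MI200-class GPU would presumably look like the block below; the `CK_AMD_GPU_GFX90A` macro name is assumed by analogy with the gfx908 flag above, not taken verbatim from the repository:

```bash
# Assumed gfx90a (MI200) configuration, by analogy with the gfx908 example above.
cmake \
-D BUILD_DEV=OFF \
-D CMAKE_BUILD_TYPE=Release \
-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX90A --amdgpu-target=gfx90a -O3 " \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_PREFIX_PATH=/opt/rocm \
..
```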
```bash
make -j ckProfiler
```
## Profile GEMM kernels
```bash
#arg1: tensor operation (gemm=GEMM)
...
#arg7: run kernel # of times (>1)
#arg8 to 13: M, N, K, StrideA, StrideB, StrideC
################ op   datatype layout verify init log repeat M___ N___ K___ StrideA StrideB StrideC
./bin/ckProfiler gemm        1      1      1    1   0      5 3840 4096 4096    4096    4096    4096
```
(In this commit the profiler binary path changes from `./profiler/ckProfiler` to `./bin/ckProfiler`.)

Result (MI100 @ 1087 MHz, 133.5 TFlops peak FP16)
```
...
c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
...
Best Perf: 1.1933 ms, 107.977 TFlops, 79.0848 GB/s
```
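As a usage sketch, the same flags can be reused with a different problem size; only the documented size arguments (arg8 to 13: M, N, K and the strides) change here, and there is no claim that this particular size is representative:

```bash
# Same op/datatype/layout/verify/init/log/repeat flags as the example above; only M, N, K and strides differ.
./bin/ckProfiler gemm 1 1 1 1 0 5 1024 1024 1024 1024 1024 1024
```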
## Profile 2d forward convolution kernels
```bash
#arg1: tensor operation (conv=Convolution)
#arg2: data type (0=fp32, 1=fp16)
...
#arg8: print matrix value (0=no, 1=yes)
#arg9: run kernel # of times (>1)
#arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx
################ op         datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads
./bin/ckProfiler conv2d_fwd 1 1 1 1 1 1 0 5 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
```
(In this commit the section title gains "2d" and the operation is renamed from `conv_fwd` to `conv2d_fwd`, invoked as `./bin/ckProfiler`.)

Result (MI100 @ 1087 MHz, 133.5 TFlops peak FP16)
...
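An illustrative variation of the same command, changing only the documented size arguments (arg10 to 24); whether this exact problem is covered by the prebuilt instances is not guaranteed:

```bash
# N=256, K=256, C=256, 3x3 filter, 28x28 input, stride 1, dilation 1, pad 1 on every side (illustrative size).
./bin/ckProfiler conv2d_fwd 1 1 1 1 1 1 0 5 256 256 256 3 3 28 28 1 1 1 1 1 1 1 1
```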
@@ -12,7 +12,7 @@
 using F16 = ck::half_t;
 using F32 = float;
-using BF16 = ushort;
+using BF16 = ck::bhalf_t;
 using INT8 = int8_t;

 namespace ck {
 namespace tensor_operation {
...
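The point of replacing ushort with a dedicated bf16 type is type safety: with a plain integer alias, bf16 values and ordinary 16-bit integers are the same type to the compiler, so they cannot take different code paths. A minimal illustration of the distinction; the bhalf_t below is a stand-in struct, not the repository's actual definition:

```cpp
#include <cstdint>
#include <cstring>

// Stand-in for a dedicated bf16 type (hypothetical; the real ck::bhalf_t differs).
struct bhalf_t
{
    uint16_t data;
};

// With `using BF16 = ushort;` these two overloads would have identical signatures and collide;
// a distinct struct type lets bf16 get its own conversion path.
inline float to_float(uint16_t u) { return static_cast<float>(u); } // plain 16-bit integer

inline float to_float(bhalf_t b) // bf16: the value is the high half of a float's bit pattern
{
    const uint32_t bits = static_cast<uint32_t>(b.data) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}
```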
@@ -64,9 +64,9 @@ template <typename DescriptionType>
 bool description_match(const DescriptionType& description,
                        int Rank,
                        const std::vector<int>& reduceDims,
-                       ReduceTensorOp_t ReduceOpId,
-                       NanPropagation_t NanOpt,
-                       ReduceTensorIndices_t IndicesOpt)
+                       ReduceTensorOp ReduceOpId,
+                       NanPropagation NanOpt,
+                       ReduceTensorIndices IndicesOpt)
 {
     if(description.Rank_ != Rank || description.ReduceOpId_ != static_cast<int>(ReduceOpId) ||
        description.NanOpt_ != static_cast<int>(NanOpt) ||
@@ -148,9 +148,9 @@ template <typename InDataType,
           typename OutDataType,
           int Rank,
           int NumReduceDim,
-          ReduceTensorOp_t ReduceOpId,
-          NanPropagation_t NanOpt,
-          ReduceTensorIndices_t IndicesOpt>
+          ReduceTensorOp ReduceOpId,
+          NanPropagation NanOpt,
+          ReduceTensorIndices IndicesOpt>
 void profile_reduce_impl_impl(bool do_verification,
                               int init_method,
                               bool do_log,
@@ -166,17 +166,17 @@ void profile_reduce_impl_impl(bool do_verification,
     using namespace ck::host_reduce;

     constexpr bool op_support_indices =
-        (ReduceOpId == ReduceTensorOp_t::MIN || ReduceOpId == ReduceTensorOp_t::MAX ||
-         ReduceOpId == ReduceTensorOp_t::AMAX);
+        (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
+         ReduceOpId == ReduceTensorOp::AMAX);

     constexpr bool NeedIndices =
-        (op_support_indices && (IndicesOpt != ReduceTensorIndices_t::NO_INDICES));
+        (op_support_indices && (IndicesOpt != ReduceTensorIndices::NO_INDICES));

-    constexpr bool PropagateNan = (NanOpt == NanPropagation_t::PROPAGATE_NAN);
+    constexpr bool PropagateNan = (NanOpt == NanPropagation::PROPAGATE_NAN);

     constexpr bool out_support_atomic_add = std::is_same<OutDataType, float>::value;
     constexpr bool op_support_atomic_add =
-        !op_support_indices && ReduceOpId != ReduceTensorOp_t::NORM2;
+        !op_support_indices && ReduceOpId != ReduceTensorOp::NORM2;
     constexpr bool use_atomic_add = (out_support_atomic_add && op_support_atomic_add);

     // 1) If InDataType is half_t, must use half_t as AccDataType for indexable reduction operations
@@ -194,7 +194,7 @@ void profile_reduce_impl_impl(bool do_verification,
     // 1) The indices can only be used when the reduction operation is indexable
     constexpr bool invalid_reduce_3 =
-        (!op_support_indices && IndicesOpt != ReduceTensorIndices_t::NO_INDICES);
+        (!op_support_indices && IndicesOpt != ReduceTensorIndices::NO_INDICES);

     // 1) If InDataType is int8_t, must use int8_t as AccDataType for indexable reduction operations
     // 2) If InDataType is int8_t, must use int32_t as AccDataType for non-indexable reduction
@@ -207,8 +207,8 @@ void profile_reduce_impl_impl(bool do_verification,
     // 1) If InDataType is int8_t, the supported operation must be either indexable operations or
     // ADD/AVG
     constexpr bool invalid_reduce_5 = std::is_same<InDataType, int8_t>::value &&
-                                      (!op_support_indices && ReduceOpId != ReduceTensorOp_t::ADD &&
-                                       ReduceOpId != ReduceTensorOp_t::AVG);
+                                      (!op_support_indices && ReduceOpId != ReduceTensorOp::ADD &&
+                                       ReduceOpId != ReduceTensorOp::AVG);

     // 1) If InDataType is bhalf_t, must use float as AccDataType for all reduction operations
     constexpr bool invalid_reduce_6 =
@@ -631,9 +631,9 @@ void profile_reduce_impl(bool do_verification,
                         int nrepeat,
                         const std::vector<size_t>& inLengths,
                         const std::vector<int>& reduceDims,
-                        ReduceTensorOp_t ReduceOpId,
-                        NanPropagation_t NanOpt,
-                        ReduceTensorIndices_t IndicesOpt,
+                        ReduceTensorOp ReduceOpId,
+                        NanPropagation NanOpt,
+                        ReduceTensorIndices IndicesOpt,
                         float alpha,
                         float beta)
 {
@@ -659,9 +659,9 @@ void profile_reduce_impl(bool do_verification,
                         OutDataType,
                         descType::Rank_,
                         descType::NumReduceDim_,
-                        static_cast<ReduceTensorOp_t>(descType::ReduceOpId_),
-                        static_cast<NanPropagation_t>(descType::NanOpt_),
-                        static_cast<ReduceTensorIndices_t>(descType::IndicesOpt_)>(
+                        static_cast<ReduceTensorOp>(descType::ReduceOpId_),
+                        static_cast<NanPropagation>(descType::NanOpt_),
+                        static_cast<ReduceTensorIndices>(descType::IndicesOpt_)>(
                         do_verification,
                         init_method,
                         do_log,
...
@@ -9,7 +9,7 @@
 int profile_batched_gemm_reduce(int argc, char* argv[])
 {
-    enum struct GemmMatrixLayout_t
+    enum struct GemmMatrixLayout
     {
         MK_KN_MN, // 0
         MK_NK_MN, // 1
@@ -17,7 +17,7 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
         KM_NK_MN, // 3
     };

-    enum struct GemmReduceDataType_t
+    enum struct GemmReduceDataType
     {
         F32_F32_F32_F32_F32, // 0
         F16_F16_F16_F32_F32, // 1
@@ -40,8 +40,8 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
         exit(1);
     }

-    const auto data_type = static_cast<GemmReduceDataType_t>(std::stoi(argv[2]));
-    const auto layout    = static_cast<GemmMatrixLayout_t>(std::stoi(argv[3]));
+    const auto data_type = static_cast<GemmReduceDataType>(std::stoi(argv[2]));
+    const auto layout    = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
     const bool do_verification = std::stoi(argv[4]);
     const int init_method      = std::stoi(argv[5]);
     const bool do_log          = std::stoi(argv[6]);
@@ -57,8 +57,7 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
     const int BatchCount = std::stoi(argv[14]);

-    if(data_type == GemmReduceDataType_t::F16_F16_F16_F32_F32 &&
-       layout == GemmMatrixLayout_t::MK_KN_MN)
+    if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
     {
         ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t,
                                                        ck::half_t,
@@ -79,8 +78,8 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
                                                        (StrideC < 0) ? N : StrideC,
                                                        BatchCount);
     }
-    else if(data_type == GemmReduceDataType_t::F16_F16_F16_F32_F32 &&
-            layout == GemmMatrixLayout_t::MK_NK_MN)
+    else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 &&
+            layout == GemmMatrixLayout::MK_NK_MN)
     {
         ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t,
                                                        ck::half_t,
@@ -101,8 +100,8 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
                                                        (StrideC < 0) ? N : StrideC,
                                                        BatchCount);
     }
-    else if(data_type == GemmReduceDataType_t::F16_F16_F16_F32_F32 &&
-            layout == GemmMatrixLayout_t::KM_KN_MN)
+    else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 &&
+            layout == GemmMatrixLayout::KM_KN_MN)
     {
         ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t,
                                                        ck::half_t,
@@ -123,8 +122,8 @@ int profile_batched_gemm_reduce(int argc, char* argv[])
                                                        (StrideC < 0) ? N : StrideC,
                                                        BatchCount);
     }
-    else if(data_type == GemmReduceDataType_t::F16_F16_F16_F32_F32 &&
-            layout == GemmMatrixLayout_t::KM_NK_MN)
+    else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 &&
+            layout == GemmMatrixLayout::KM_NK_MN)
     {
         ck::profiler::profile_batched_gemm_reduce_impl<ck::half_t,
                                                        ck::half_t,
...
@@ -7,7 +7,7 @@
 #include "profile_convnd_bwd_data_impl.hpp"

-enum ConvDataType
+enum struct ConvDataType
 {
     F32_F32_F32, // 0
     F16_F16_F16, // 1
@@ -15,19 +15,19 @@ enum ConvDataType
     INT8_INT8_INT8, // 3
 };

-enum ConvInputLayout
+enum struct ConvInputLayout
 {
     NCHW, // 0
     NHWC, // 1
 };

-enum ConvWeightLayout
+enum struct ConvWeightLayout
 {
     KCYX, // 0
     KYXC, // 1
 };

-enum ConvOutputLayout
+enum struct ConvOutputLayout
 {
     NKHW, // 0
     NHWK, // 1
@@ -97,10 +97,10 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial)
         return 1;
     }

-    const int data_type   = static_cast<ConvDataType>(std::stoi(argv[2]));
-    const int in_layout   = static_cast<ConvInputLayout>(std::stoi(argv[3]));
-    const int wei_layout  = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
-    const int out_layout  = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
+    const auto data_type  = static_cast<ConvDataType>(std::stoi(argv[2]));
+    const auto in_layout  = static_cast<ConvInputLayout>(std::stoi(argv[3]));
+    const auto wei_layout = static_cast<ConvWeightLayout>(std::stoi(argv[4]));
+    const auto out_layout = static_cast<ConvOutputLayout>(std::stoi(argv[5]));
     const bool do_verification = std::stoi(argv[6]);
     const int init_method      = std::stoi(argv[7]);
     const bool do_log          = std::stoi(argv[8]);
...
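The enum to enum struct change here (together with const int becoming const auto) is the usual scoped-enum cleanup: a scoped enumerator no longer converts implicitly to int, so the variable holding the static_cast result must keep the enum type. A minimal sketch of the pattern, not code from the repository:

```cpp
#include <cstdlib>
#include <iostream>

enum struct ConvDataType // scoped: no implicit conversion to or from int
{
    F32_F32_F32, // 0
    F16_F16_F16, // 1
};

int main(int argc, char* argv[])
{
    // Map an integer selector from the command line onto the enum.
    const auto data_type = static_cast<ConvDataType>(argc > 1 ? std::atoi(argv[1]) : 1);

    // With the scoped enum, `const int data_type = static_cast<ConvDataType>(...)` would not
    // compile, because the enum value no longer converts back to int implicitly.
    if(data_type == ConvDataType::F16_F16_F16)
        std::cout << "fp16 path selected\n";

    return 0;
}
```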
@@ -8,7 +8,7 @@
 int profile_gemm_reduce(int argc, char* argv[])
 {
-    enum struct GemmMatrixLayout_t
+    enum struct GemmMatrixLayout
     {
         MK_KN_MN, // 0
         MK_NK_MN, // 1
@@ -16,7 +16,7 @@ int profile_gemm_reduce(int argc, char* argv[])
         KM_NK_MN, // 3
     };

-    enum struct GemmReduceDataType_t
+    enum struct GemmReduceDataType
     {
         F32_F32_F32_F32_F32, // 0
         F16_F16_F16_F32_F32, // 1
@@ -39,8 +39,8 @@ int profile_gemm_reduce(int argc, char* argv[])
         exit(1);
     }

-    const auto data_type = static_cast<GemmReduceDataType_t>(std::stoi(argv[2]));
-    const auto layout    = static_cast<GemmMatrixLayout_t>(std::stoi(argv[3]));
+    const auto data_type = static_cast<GemmReduceDataType>(std::stoi(argv[2]));
+    const auto layout    = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
     const bool do_verification = std::stoi(argv[4]);
     const int init_method      = std::stoi(argv[5]);
     const bool do_log          = std::stoi(argv[6]);
@@ -54,8 +54,7 @@ int profile_gemm_reduce(int argc, char* argv[])
     const int StrideB = std::stoi(argv[12]);
     const int StrideC = std::stoi(argv[13]);

-    if(data_type == GemmReduceDataType_t::F16_F16_F16_F32_F32 &&
-       layout == GemmMatrixLayout_t::MK_KN_MN)
+    if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
     {
         ck::profiler::profile_gemm_reduce_impl<ck::half_t,
                                                ck::half_t,
@@ -75,8 +74,8 @@ int profile_gemm_reduce(int argc, char* argv[])
                                                (StrideB < 0) ? N : StrideB,
                                                (StrideC < 0) ? N : StrideC);
     }
-    else if(data_type == GemmReduceDataType_t::F16_F16_F16_F32_F32 &&
-            layout == GemmMatrixLayout_t::MK_NK_MN)
+    else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 &&
+            layout == GemmMatrixLayout::MK_NK_MN)
     {
         ck::profiler::profile_gemm_reduce_impl<ck::half_t,
                                                ck::half_t,
@@ -96,8 +95,8 @@ int profile_gemm_reduce(int argc, char* argv[])
                                                (StrideB < 0) ? K : StrideB,
                                                (StrideC < 0) ? N : StrideC);
     }
-    else if(data_type == GemmReduceDataType_t::F16_F16_F16_F32_F32 &&
-            layout == GemmMatrixLayout_t::KM_KN_MN)
+    else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 &&
+            layout == GemmMatrixLayout::KM_KN_MN)
     {
         ck::profiler::profile_gemm_reduce_impl<ck::half_t,
                                                ck::half_t,
@@ -117,8 +116,8 @@ int profile_gemm_reduce(int argc, char* argv[])
                                                (StrideB < 0) ? N : StrideB,
                                                (StrideC < 0) ? N : StrideC);
     }
-    else if(data_type == GemmReduceDataType_t::F16_F16_F16_F32_F32 &&
-            layout == GemmMatrixLayout_t::KM_NK_MN)
+    else if(data_type == GemmReduceDataType::F16_F16_F16_F32_F32 &&
+            layout == GemmMatrixLayout::KM_NK_MN)
     {
         ck::profiler::profile_gemm_reduce_impl<ck::half_t,
                                                ck::half_t,
...
@@ -6,7 +6,7 @@
 #include <half.hpp>
 #include "profile_grouped_gemm_impl.hpp"

-enum GemmMatrixLayout
+enum struct GemmMatrixLayout
 {
     MK_KN_MN, // 0
     MK_NK_MN, // 1
@@ -18,7 +18,7 @@ enum GemmMatrixLayout
     KM_NK_NM, // 7
 };

-enum GemmDataType
+enum struct GemmDataType
 {
     F32_F32_F32, // 0
     F16_F16_F16, // 1
@@ -61,8 +61,8 @@ int profile_grouped_gemm(int argc, char* argv[])
         exit(1);
     }

-    const int data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
-    const int layout    = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
+    const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
+    const auto layout    = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
     const bool do_verification = std::stoi(argv[4]);
     const int init_method      = std::stoi(argv[5]);
     const bool do_log          = std::stoi(argv[6]);
...
@@ -20,9 +20,9 @@
 using namespace std;

-using ck::NanPropagation_t;
-using ck::ReduceTensorIndices_t;
-using ck::ReduceTensorOp_t;
+using ck::NanPropagation;
+using ck::ReduceTensorIndices;
+using ck::ReduceTensorOp;

 static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'},
                                        {"reduceDims", required_argument, nullptr, 'R'},
@@ -84,7 +84,7 @@ static std::vector<T> getTypeValuesFromString(const char* cstr_values)
     return (values);
 }

-enum struct appDataType_t
+enum struct AppDataType
 {
     appHalf  = 0,
     appFloat = 1,
@@ -130,18 +130,18 @@ class AppArgs
     std::vector<float> scales;

-    ReduceTensorOp_t reduceOp = ReduceTensorOp_t::ADD;
+    ReduceTensorOp reduceOp = ReduceTensorOp::ADD;

-    appDataType_t compTypeId = appDataType_t::appFloat;
-    appDataType_t outTypeId  = appDataType_t::appFloat;
+    AppDataType compTypeId = AppDataType::appFloat;
+    AppDataType outTypeId  = AppDataType::appFloat;

     bool compType_assigned = false;
     bool outType_assigned  = false;

-    NanPropagation_t nanOpt          = NanPropagation_t::NOT_PROPAGATE_NAN;
-    ReduceTensorIndices_t indicesOpt = ReduceTensorIndices_t::NO_INDICES;
+    NanPropagation nanOpt          = NanPropagation::NOT_PROPAGATE_NAN;
+    ReduceTensorIndices indicesOpt = ReduceTensorIndices::NO_INDICES;

     bool do_log          = false;
     bool do_verification = false;
     bool do_dumpout      = false;

     int init_method;
     int nrepeat;
@@ -213,33 +213,33 @@ class AppArgs
                 if(!optarg)
                     throw std::runtime_error("Invalid option format!");

-                reduceOp = static_cast<ReduceTensorOp_t>(std::atoi(optarg));
+                reduceOp = static_cast<ReduceTensorOp>(std::atoi(optarg));
                 break;
             case 'C':
                 if(!optarg)
                     throw std::runtime_error("Invalid option format!");

-                compTypeId = static_cast<appDataType_t>(std::atoi(optarg));
+                compTypeId = static_cast<AppDataType>(std::atoi(optarg));
                 compType_assigned = true;
                 break;
             case 'W':
                 if(!optarg)
                     throw std::runtime_error("Invalid option format!");

-                outTypeId = static_cast<appDataType_t>(std::atoi(optarg));
+                outTypeId = static_cast<AppDataType>(std::atoi(optarg));
                 outType_assigned = true;
                 break;
             case 'N':
                 if(!optarg)
                     throw std::runtime_error("Invalid option format!");

-                nanOpt = static_cast<NanPropagation_t>(std::atoi(optarg));
+                nanOpt = static_cast<NanPropagation>(std::atoi(optarg));
                 break;
             case 'I':
                 if(!optarg)
                     throw std::runtime_error("Invalid option format!");

-                indicesOpt = static_cast<ReduceTensorIndices_t>(std::atoi(optarg));
+                indicesOpt = static_cast<ReduceTensorIndices>(std::atoi(optarg));
                 break;
             case 'S':
                 if(!optarg)
@@ -303,10 +303,10 @@ class AppArgs
             scales.push_back(0.0f);
         };

-        if(reduceOp == ReduceTensorOp_t::MIN || reduceOp == ReduceTensorOp_t::MAX ||
-           reduceOp == ReduceTensorOp_t::AMAX)
+        if(reduceOp == ReduceTensorOp::MIN || reduceOp == ReduceTensorOp::MAX ||
+           reduceOp == ReduceTensorOp::AMAX)
         {
-            if(indicesOpt != ReduceTensorIndices_t::NO_INDICES)
+            if(indicesOpt != ReduceTensorIndices::NO_INDICES)
                 need_indices = true;

             // for indexable operations, no need to assign compType and outType, just let them be
@@ -333,22 +333,22 @@ int profile_reduce(int argc, char* argv[])
         check_reduce_dims(rank, args.reduceDims);

-        if(args.reduceOp == ReduceTensorOp_t::MUL || args.reduceOp == ReduceTensorOp_t::NORM1)
+        if(args.reduceOp == ReduceTensorOp::MUL || args.reduceOp == ReduceTensorOp::NORM1)
             throw std::runtime_error("MUL and NORM1 are not supported by composable kernel!");

         if(args.use_half)
         {
             if(!args.compType_assigned)
-                args.compTypeId = appDataType_t::appHalf;
+                args.compTypeId = AppDataType::appHalf;

             if(args.outType_assigned &&
-               (args.outTypeId != appDataType_t::appHalf && args.outTypeId != appDataType_t::appFloat))
-                args.outTypeId = appDataType_t::appFloat;
+               (args.outTypeId != AppDataType::appHalf && args.outTypeId != AppDataType::appFloat))
+                args.outTypeId = AppDataType::appFloat;

             if(!args.outType_assigned)
-                args.outTypeId = appDataType_t::appHalf;
+                args.outTypeId = AppDataType::appHalf;

-            if(args.compTypeId == appDataType_t::appHalf)
+            if(args.compTypeId == AppDataType::appHalf)
             {
                 profile_reduce_impl<ck::half_t, ck::half_t, ck::half_t>(args.do_verification,
                                                                         args.init_method,
@@ -363,7 +363,7 @@ int profile_reduce(int argc, char* argv[])
                                                                         args.scales[0],
                                                                         args.scales[1]);
             }
-            else if(args.compTypeId == appDataType_t::appFloat)
+            else if(args.compTypeId == AppDataType::appFloat)
             {
                 profile_reduce_impl<ck::half_t, float, ck::half_t>(args.do_verification,
                                                                    args.init_method,
@@ -399,16 +399,16 @@ int profile_reduce(int argc, char* argv[])
         else if(args.use_int8)
         {
             if(!args.compType_assigned)
-                args.compTypeId = appDataType_t::appInt8;
+                args.compTypeId = AppDataType::appInt8;

             if(args.outType_assigned &&
-               (args.outTypeId != appDataType_t::appInt8 && args.outTypeId != appDataType_t::appInt32))
-                args.outTypeId = appDataType_t::appInt32;
+               (args.outTypeId != AppDataType::appInt8 && args.outTypeId != AppDataType::appInt32))
+                args.outTypeId = AppDataType::appInt32;

             if(!args.outType_assigned)
-                args.outTypeId = appDataType_t::appInt8;
+                args.outTypeId = AppDataType::appInt8;

-            if(args.compTypeId == appDataType_t::appInt8)
+            if(args.compTypeId == AppDataType::appInt8)
             {
                 profile_reduce_impl<int8_t, int8_t, int8_t>(args.do_verification,
                                                             args.init_method,
@@ -423,7 +423,7 @@ int profile_reduce(int argc, char* argv[])
                                                             args.scales[0],
                                                             args.scales[1]);
             }
-            else if(args.compTypeId == appDataType_t::appInt32)
+            else if(args.compTypeId == AppDataType::appInt32)
             {
                 profile_reduce_impl<int8_t, int32_t, int8_t>(args.do_verification,
                                                              args.init_method,
@@ -443,12 +443,12 @@ int profile_reduce(int argc, char* argv[])
         }
         else if(args.use_bf16)
         {
-            if(args.outType_assigned && (args.outTypeId != appDataType_t::appBFloat16 &&
-                                         args.outTypeId != appDataType_t::appFloat))
-                args.outTypeId = appDataType_t::appFloat;
+            if(args.outType_assigned &&
+               (args.outTypeId != AppDataType::appBFloat16 && args.outTypeId != AppDataType::appFloat))
+                args.outTypeId = AppDataType::appFloat;

             if(!args.outType_assigned)
-                args.outTypeId = appDataType_t::appBFloat16;
+                args.outTypeId = AppDataType::appBFloat16;

             profile_reduce_impl<ck::bhalf_t, float, ck::bhalf_t>(args.do_verification,
                                                                  args.init_method,
@@ -465,7 +465,7 @@ int profile_reduce(int argc, char* argv[])
         }
         else
         {
-            if(args.compTypeId == appDataType_t::appFloat)
+            if(args.compTypeId == AppDataType::appFloat)
             {
                 profile_reduce_impl<float, float, float>(args.do_verification,
                                                          args.init_method,
@@ -480,7 +480,7 @@ int profile_reduce(int argc, char* argv[])
                                                          args.scales[0],
                                                          args.scales[1]);
             }
-            else if(args.compTypeId == appDataType_t::appDouble)
+            else if(args.compTypeId == AppDataType::appDouble)
             {
                 profile_reduce_impl<float, double, float>(args.do_verification,
                                                           args.init_method,
...