Commit 30348daa authored by rocking
Browse files

[What] Refine naming

[Why] Prepare to add reduceSum
parent f2540aa5
...@@ -85,25 +85,25 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle ...@@ -85,25 +85,25 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle
8>; // CBlockTransferScalarPerVector_NWaveNPerXdl 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on // clang-format on
constexpr int Rank = 2; constexpr int ReduceRank = 2;
constexpr int NumReduceDim = 1; constexpr int NumReduceDim = 1;
constexpr ck::ReduceTensorOp ReduceOpId = ck::ReduceTensorOp::MAX; constexpr ck::ReduceTensorOp ReduceMaxId = ck::ReduceTensorOp::MAX;
constexpr ck::NanPropagation NanOpt = ck::NanPropagation::PROPAGATE_NAN; constexpr ck::NanPropagation NanOpt = ck::NanPropagation::PROPAGATE_NAN;
constexpr bool PropagateNan = (NanOpt == ck::NanPropagation::NOT_PROPAGATE_NAN) ? false : true; constexpr bool PropagateNan = (NanOpt == ck::NanPropagation::NOT_PROPAGATE_NAN) ? false : true;
// constexpr ck::ReduceTensorIndices_t IndicesOpt = ck::ReduceTensorIndices_t::NO_INDICES; // constexpr ck::ReduceTensorIndices_t IndicesOpt = ck::ReduceTensorIndices_t::NO_INDICES;
using ReduceOperation = typename ck::reduce_binary_operator<CDataType, ReduceOpId>::opType; using ReduceMaxOp = typename ck::reduce_binary_operator<CDataType, ReduceMaxId>::opType;
using InElementwiseOperation = using InElementwiseOperation =
typename ck::reduce_unary_operator<CDataType, ReduceOpId, true, true>::InElementwiseOperation; typename ck::reduce_unary_operator<CDataType, ReduceMaxId, true, true>::InElementwiseOperation;
using AccElementwiseOperation = using AccElementwiseOperation =
typename ck::reduce_unary_operator<CDataType, ReduceOpId, true, true>::AccElementwiseOperation; typename ck::reduce_unary_operator<CDataType, ReduceMaxId, true, true>::AccElementwiseOperation;
using DeviceReduceInstance = using DeviceReduceMaxInstance =
ck::tensor_operation::device::DeviceReduceBlockWise<CDataType, ck::tensor_operation::device::DeviceReduceBlockWise<CDataType,
CDataType, CDataType,
CDataType, CDataType,
Rank, ReduceRank,
NumReduceDim, NumReduceDim,
ReduceOperation, ReduceMaxOp,
InElementwiseOperation, InElementwiseOperation,
AccElementwiseOperation, AccElementwiseOperation,
PropagateNan, PropagateNan,
...@@ -264,7 +264,7 @@ int main(int argc, char* argv[]) ...@@ -264,7 +264,7 @@ int main(int argc, char* argv[])
gemm_invoker.Run(gemm_argument, nrepeat); gemm_invoker.Run(gemm_argument, nrepeat);
// do reduce max // do reduce max
auto reduce_max = DeviceReduceInstance{}; auto reduce_max = DeviceReduceMaxInstance{};
auto wsSizeInBytes = reduce_max.GetWorkspaceSizeInBytes(i_inLengths, reduceDims); auto wsSizeInBytes = reduce_max.GetWorkspaceSizeInBytes(i_inLengths, reduceDims);
DeviceMem ws_dev(wsSizeInBytes); DeviceMem ws_dev(wsSizeInBytes);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment