Unverified Commit 86185bd7 authored by Qianfeng's avatar Qianfeng Committed by GitHub
Browse files

Unify the naming of the math functions used by the host and kernel (#262)

* Use the unified naming for math functions on host and HIP kernel

* Corresponding change/simplification in reduction host/profiler/examples due to unified math functions renaming

* Renaming GetReductionZeroVal() to GetIdentityValue()

* Tiny renaming in profile_reduce_impl.hpp

* More renaming in profile_reduce_impl.hpp

* Replace zeroVal by identiyVal

* Remove ck_ prefix in the naming of ck::math provided functions
parent b6eaf3eb
...@@ -165,8 +165,8 @@ bool profile_gemm_reduce_impl(int do_verification, ...@@ -165,8 +165,8 @@ bool profile_gemm_reduce_impl(int do_verification,
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
float d0_acc = d0_reduce_op.GetReductionZeroVal(); float d0_acc = d0_reduce_op.GetIdentityValue();
float d1_acc = d1_reduce_op.GetReductionZeroVal(); float d1_acc = d1_reduce_op.GetIdentityValue();
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
......
...@@ -138,7 +138,6 @@ bool profile_reduce_impl_impl(bool do_verification, ...@@ -138,7 +138,6 @@ bool profile_reduce_impl_impl(bool do_verification,
{ {
using namespace ck::tensor_operation::device; using namespace ck::tensor_operation::device;
using namespace ck::tensor_operation::device::device_reduce_instance; using namespace ck::tensor_operation::device::device_reduce_instance;
using namespace ck::host_reduce;
using ck::host_common::dumpBufferToFile; using ck::host_common::dumpBufferToFile;
constexpr bool op_support_indices = constexpr bool op_support_indices =
...@@ -261,15 +260,17 @@ bool profile_reduce_impl_impl(bool do_verification, ...@@ -261,15 +260,17 @@ bool profile_reduce_impl_impl(bool do_verification,
float best_avg_time = 0; float best_avg_time = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
using InElementwiseOperation_0 = using InElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>:: typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
InElementwiseOperation; InElementwiseOperation;
using AccElementwiseOperation_0 = using AccElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>:: typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
AccElementwiseOperation; AccElementwiseOperation;
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
using DeviceReduceInstPtr0 = using DeviceReduceInstPtr0 =
DeviceReducePtr<InElementwiseOperation_0, AccElementwiseOperation_0>; DeviceReducePtr<InElementwiseOperation, AccElementwiseOperation>;
std::vector<DeviceReduceInstPtr0> reduce0_ptrs; std::vector<DeviceReduceInstPtr0> reduce0_ptrs;
...@@ -313,7 +314,9 @@ bool profile_reduce_impl_impl(bool do_verification, ...@@ -313,7 +314,9 @@ bool profile_reduce_impl_impl(bool do_verification,
ReductionHost<InDataType, ReductionHost<InDataType,
AccDataType, AccDataType,
OutDataType, OutDataType,
ReduceOpId, ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
Rank, Rank,
NumReduceDim, NumReduceDim,
PropagateNan, PropagateNan,
...@@ -337,9 +340,8 @@ bool profile_reduce_impl_impl(bool do_verification, ...@@ -337,9 +340,8 @@ bool profile_reduce_impl_impl(bool do_verification,
for(auto& reduce_ptr : reduce0_ptrs) for(auto& reduce_ptr : reduce0_ptrs)
{ {
InElementwiseOperation_0 in_elementwise_op_0(static_cast<int32_t>(reduce_total_length)); InElementwiseOperation in_elementwise_op(static_cast<int32_t>(reduce_total_length));
AccElementwiseOperation_0 acc_elementwise_op_0( AccElementwiseOperation acc_elementwise_op(static_cast<int32_t>(reduce_total_length));
static_cast<int32_t>(reduce_total_length));
auto argument_ptr = reduce_ptr->MakeArgumentPointer(i_inLengths, auto argument_ptr = reduce_ptr->MakeArgumentPointer(i_inLengths,
i_inStrides, i_inStrides,
...@@ -352,8 +354,8 @@ bool profile_reduce_impl_impl(bool do_verification, ...@@ -352,8 +354,8 @@ bool profile_reduce_impl_impl(bool do_verification,
nullptr, nullptr,
out_dev.GetDeviceBuffer(), out_dev.GetDeviceBuffer(),
out_indices_dev.GetDeviceBuffer(), out_indices_dev.GetDeviceBuffer(),
in_elementwise_op_0, in_elementwise_op,
acc_elementwise_op_0); acc_elementwise_op);
if(!reduce_ptr->IsSupportedArgument(argument_ptr.get())) if(!reduce_ptr->IsSupportedArgument(argument_ptr.get()))
continue; continue;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment