Unverified Commit 86185bd7 authored by Qianfeng's avatar Qianfeng Committed by GitHub
Browse files

Unify the naming of the math functions used by the host and kernel (#262)

* Use the unified naming for math functions on host and HIP kernel

* Corresponding change/simplification in reduction host/profiler/examples due to unified math functions renaming

* Renaming GetReductionZeroVal() to GetIdentityValue()

* Tiny renaming in profile_reduce_impl.hpp

* More renaming in profile_reduce_impl.hpp

* Replace zeroVal by identiyVal

* Remove ck_ prefix in the naming of ck::math provided functions
parent b6eaf3eb
...@@ -147,8 +147,6 @@ class SimpleAppArgs ...@@ -147,8 +147,6 @@ class SimpleAppArgs
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
using namespace ck::host_reduce;
const std::vector<int> reduceDims{0, 1, 2}; const std::vector<int> reduceDims{0, 1, 2};
const std::vector<int> invariantDims{3}; const std::vector<int> invariantDims{3};
...@@ -254,7 +252,9 @@ int main(int argc, char* argv[]) ...@@ -254,7 +252,9 @@ int main(int argc, char* argv[])
ReductionHost<InDataType, ReductionHost<InDataType,
AccDataType, AccDataType,
OutDataType, OutDataType,
ReduceOpId, ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
Rank, Rank,
NumReduceDim, NumReduceDim,
PropagateNan, PropagateNan,
......
...@@ -108,8 +108,6 @@ int main(int argc, char* argv[]) ...@@ -108,8 +108,6 @@ int main(int argc, char* argv[])
const std::vector<size_t> outLengths = {64, 320, 80}; const std::vector<size_t> outLengths = {64, 320, 80};
using namespace ck::host_reduce;
if(argc == 1) if(argc == 1)
{ {
do_verify = true; do_verify = true;
...@@ -191,7 +189,9 @@ int main(int argc, char* argv[]) ...@@ -191,7 +189,9 @@ int main(int argc, char* argv[])
ReductionHost<InOutDataType, ReductionHost<InOutDataType,
AccDataType, AccDataType,
InOutDataType, InOutDataType,
ReduceOpId, ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
5, // Rank 5, // Rank
2, // NumReduceDim 2, // NumReduceDim
PropagateNan, PropagateNan,
......
...@@ -8,10 +8,12 @@ ...@@ -8,10 +8,12 @@
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "host_tensor_generator.hpp" #include "host_tensor_generator.hpp"
#include "host_reduce_util.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "tensor_layout.hpp" #include "tensor_layout.hpp"
#include "reduction_enums.hpp" #include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "reduction_functions_accumulate.hpp"
#include "device_pool2d_fwd_nhwc_nhwc.hpp" #include "device_pool2d_fwd_nhwc_nhwc.hpp"
template <typename InDataType, template <typename InDataType,
...@@ -29,19 +31,24 @@ static void pool_host_verify(const Tensor<InDataType>& in, ...@@ -29,19 +31,24 @@ static void pool_host_verify(const Tensor<InDataType>& in,
const std::array<ck::index_t, 2>& in_left_pads, const std::array<ck::index_t, 2>& in_left_pads,
const std::array<ck::index_t, 2>& /*in_right_pads*/) const std::array<ck::index_t, 2>& /*in_right_pads*/)
{ {
using namespace ck::host_reduce;
const int32_t divider = window_spatial_lengths[0] * window_spatial_lengths[1]; const int32_t divider = window_spatial_lengths[0] * window_spatial_lengths[1];
const auto PreUnaryOp = PreUnaryOpFn<AccDataType, ReduceOpId>(divider); using ReduceOperation = typename ck::reduce_binary_operator<AccDataType, ReduceOpId>::opType;
const auto PosUnaryOp = PosUnaryOpFn<AccDataType, ReduceOpId>(divider); using InElementwiseOperation = typename ck::
reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation = typename ck::
reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation;
const InElementwiseOperation in_elementwise_op(divider);
const AccElementwiseOperation acc_elementwise_op(divider);
if constexpr(!OutputIndex) if constexpr(!OutputIndex)
{ {
auto opReduce = ReduceOpFn<AccDataType, ReduceOpId>(); using Accumulation =
ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
auto accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>(); auto accuVal = ReduceOperation::GetIdentityValue();
for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y) for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y)
{ {
...@@ -54,14 +61,14 @@ static void pool_host_verify(const Tensor<InDataType>& in, ...@@ -54,14 +61,14 @@ static void pool_host_verify(const Tensor<InDataType>& in,
{ {
AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi)); AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi));
PreUnaryOp(currVal); in_elementwise_op(currVal, currVal);
binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal); Accumulation::Calculate(accuVal, currVal);
} }
} }
} }
PosUnaryOp(accuVal); acc_elementwise_op(accuVal, accuVal);
out(n, c, ho, wo) = accuVal; out(n, c, ho, wo) = accuVal;
}; };
...@@ -74,10 +81,12 @@ static void pool_host_verify(const Tensor<InDataType>& in, ...@@ -74,10 +81,12 @@ static void pool_host_verify(const Tensor<InDataType>& in,
} }
else else
{ {
auto opReduce = ReduceOpFn2<AccDataType, ReduceOpId>(); using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { AccDataType,
auto accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>(); IndexDataType>;
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
auto accuVal = ReduceOperation::GetIdentityValue();
IndexDataType accuIndex = 0; IndexDataType accuIndex = 0;
for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y) for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y)
...@@ -92,15 +101,14 @@ static void pool_host_verify(const Tensor<InDataType>& in, ...@@ -92,15 +101,14 @@ static void pool_host_verify(const Tensor<InDataType>& in,
AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi)); AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi));
IndexDataType currIndex = y * window_spatial_lengths[1] + x; IndexDataType currIndex = y * window_spatial_lengths[1] + x;
PreUnaryOp(currVal); in_elementwise_op(currVal, currVal);
binop_with_index_and_nan_check<AccDataType, IndexDataType, PropagateNan>( Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
opReduce, accuVal, currVal, accuIndex, currIndex);
} }
} }
} }
PosUnaryOp(accuVal); acc_elementwise_op(accuVal, accuVal);
out(n, c, ho, wo) = accuVal; out(n, c, ho, wo) = accuVal;
out_indices(n, c, ho, wo) = accuIndex; out_indices(n, c, ho, wo) = accuIndex;
...@@ -139,8 +147,6 @@ bool pool_test(bool do_verification, ...@@ -139,8 +147,6 @@ bool pool_test(bool do_verification,
ck::index_t in_right_pad_h, ck::index_t in_right_pad_h,
ck::index_t in_right_pad_w) ck::index_t in_right_pad_w)
{ {
using namespace ck::host_reduce;
using DevicePoolFwdInstance = using DevicePoolFwdInstance =
ck::tensor_operation::device::DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C< ck::tensor_operation::device::DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<
InDataType, // InDataType InDataType, // InDataType
......
...@@ -27,8 +27,6 @@ static constexpr bool PropagateNan = false; ...@@ -27,8 +27,6 @@ static constexpr bool PropagateNan = false;
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
using namespace ck::host_reduce;
bool do_verification; bool do_verification;
int init_method; int init_method;
bool time_kernel; bool time_kernel;
......
...@@ -27,8 +27,6 @@ static constexpr bool PropagateNan = false; ...@@ -27,8 +27,6 @@ static constexpr bool PropagateNan = false;
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
using namespace ck::host_reduce;
bool do_verification; bool do_verification;
int init_method; int init_method;
bool time_kernel; bool time_kernel;
......
...@@ -236,7 +236,7 @@ int main(int argc, char* argv[]) ...@@ -236,7 +236,7 @@ int main(int argc, char* argv[])
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
ReduceAccDataType d_acc = d_reduce_op.GetReductionZeroVal(); ReduceAccDataType d_acc = d_reduce_op.GetIdentityValue();
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
d_reduce_op(d_acc, c_m_n_host_result(m, n)); d_reduce_op(d_acc, c_m_n_host_result(m, n));
......
...@@ -261,8 +261,8 @@ int main(int argc, char* argv[]) ...@@ -261,8 +261,8 @@ int main(int argc, char* argv[])
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
float d0_acc = d0_reduce_op.GetReductionZeroVal(); float d0_acc = d0_reduce_op.GetIdentityValue();
float d1_acc = d1_reduce_op.GetReductionZeroVal(); float d1_acc = d1_reduce_op.GetIdentityValue();
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
......
...@@ -259,8 +259,8 @@ int main(int argc, char* argv[]) ...@@ -259,8 +259,8 @@ int main(int argc, char* argv[])
{ {
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
float d0_acc = d0_reduce_op.GetReductionZeroVal(); float d0_acc = d0_reduce_op.GetIdentityValue();
float d1_acc = d1_reduce_op.GetReductionZeroVal(); float d1_acc = d1_reduce_op.GetIdentityValue();
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
......
...@@ -157,8 +157,8 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n, ...@@ -157,8 +157,8 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
auto reduceSumOpInst = ReduceSumOp{}; auto reduceSumOpInst = ReduceSumOp{};
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
float mean_acc = reduceSumOpInst.GetReductionZeroVal(); float mean_acc = reduceSumOpInst.GetIdentityValue();
float square_mean_acc = reduceSumOpInst.GetReductionZeroVal(); float square_mean_acc = reduceSumOpInst.GetIdentityValue();
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
......
...@@ -348,8 +348,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE ...@@ -348,8 +348,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
if constexpr(use_multiblock) if constexpr(use_multiblock)
{ {
const auto zeroVal = const auto identityVal =
ck::reduce::GetReductionZeroValueForInMemoryDataOperation<OutDataType>( ck::reduce::GetIdentityValueueForInMemoryDataOperation<OutDataType>(
OutMemoryDataOperation); OutMemoryDataOperation);
const auto kernel_pre = const auto kernel_pre =
...@@ -362,7 +362,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE ...@@ -362,7 +362,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
0, 0,
out_grid_desc_m_2, out_grid_desc_m_2,
arg.out_dev_, arg.out_dev_,
zeroVal); identityVal);
}; };
avg_time += launch_and_time_kernel(stream_config, avg_time += launch_and_time_kernel(stream_config,
......
#pragma once #pragma once
#include "data_type.hpp" #include "data_type.hpp"
#include "math_v2.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -296,7 +297,7 @@ struct UnaryAbs<float, float> ...@@ -296,7 +297,7 @@ struct UnaryAbs<float, float>
{ {
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; }; __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(float& y, const float& x) const { y = abs(x); }; __host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::abs(x); };
}; };
template <> template <>
...@@ -304,7 +305,7 @@ struct UnaryAbs<half_t, half_t> ...@@ -304,7 +305,7 @@ struct UnaryAbs<half_t, half_t>
{ {
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; }; __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(half_t& y, const half_t& x) const { y = __habs(x); }; __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = ck::math::abs(x); };
}; };
template <> template <>
...@@ -312,7 +313,7 @@ struct UnaryAbs<double, double> ...@@ -312,7 +313,7 @@ struct UnaryAbs<double, double>
{ {
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; }; __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(double& y, const double& x) const { y = abs(x); }; __host__ __device__ void operator()(double& y, const double& x) const { y = ck::math::abs(x); };
}; };
template <> template <>
...@@ -320,12 +321,7 @@ struct UnaryAbs<int8_t, int8_t> ...@@ -320,12 +321,7 @@ struct UnaryAbs<int8_t, int8_t>
{ {
__host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; }; __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(int8_t& y, const int8_t& x) const __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = ck::math::abs(x); };
{
int8_t sgn = x >> (8 - 1);
y = (x ^ sgn) - sgn;
};
}; };
template <typename Y, typename X> template <typename Y, typename X>
...@@ -336,7 +332,7 @@ struct UnarySqrt<float, float> ...@@ -336,7 +332,7 @@ struct UnarySqrt<float, float>
{ {
__host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; }; __host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(float& y, const float& x) const { y = sqrtf(x); }; __host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::sqrt(x); };
}; };
template <> template <>
...@@ -344,7 +340,10 @@ struct UnarySqrt<double, double> ...@@ -344,7 +340,10 @@ struct UnarySqrt<double, double>
{ {
__host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; }; __host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };
__host__ __device__ void operator()(double& y, const double& x) const { y = sqrt(x); }; __host__ __device__ void operator()(double& y, const double& x) const
{
y = ck::math::sqrt(x);
};
}; };
} // namespace element_wise } // namespace element_wise
......
...@@ -171,7 +171,7 @@ struct GridwiseReduction_mk_to_m_multiblock ...@@ -171,7 +171,7 @@ struct GridwiseReduction_mk_to_m_multiblock
AccDataType beta, AccDataType beta,
OutDataType* const __restrict__ p_out_value_global) OutDataType* const __restrict__ p_out_value_global)
{ {
const auto zeroVal = ReduceOperation::GetReductionZeroVal(); const auto identityVal = ReduceOperation::GetIdentityValue();
// LDS // LDS
__shared__ AccDataType p_reduce_work_buffer[BlockSize]; __shared__ AccDataType p_reduce_work_buffer[BlockSize];
...@@ -179,7 +179,7 @@ struct GridwiseReduction_mk_to_m_multiblock ...@@ -179,7 +179,7 @@ struct GridwiseReduction_mk_to_m_multiblock
const auto in_global_val_buf = const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global, make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(), in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal)); type_convert<InDataType>(identityVal));
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_value_global, out_grid_desc_m.GetElementSpaceSize()); p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
...@@ -191,7 +191,7 @@ struct GridwiseReduction_mk_to_m_multiblock ...@@ -191,7 +191,7 @@ struct GridwiseReduction_mk_to_m_multiblock
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf; StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; }); static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });
const index_t thread_local_id = get_thread_local_1d_id(); const index_t thread_local_id = get_thread_local_1d_id();
const index_t block_global_id = get_block_1d_id(); const index_t block_global_id = get_block_1d_id();
...@@ -358,12 +358,12 @@ struct GridwiseReduction_mk_to_m_multiblock ...@@ -358,12 +358,12 @@ struct GridwiseReduction_mk_to_m_multiblock
__shared__ AccDataType p_reduce_work_val_buffer[BlockSize]; __shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
__shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize]; __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
const auto zeroVal = ReduceOperation::GetReductionZeroVal(); const auto identityVal = ReduceOperation::GetIdentityValue();
const auto in_global_val_buf = const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global, make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(), in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal)); type_convert<InDataType>(identityVal));
const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize()); p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
...@@ -418,7 +418,7 @@ struct GridwiseReduction_mk_to_m_multiblock ...@@ -418,7 +418,7 @@ struct GridwiseReduction_mk_to_m_multiblock
thread_k_cluster_id * KThreadSliceSize)); thread_k_cluster_id * KThreadSliceSize));
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = zeroVal; accu_value_buf(I) = identityVal;
accu_index_buf(I) = 0; accu_index_buf(I) = 0;
}); });
...@@ -459,7 +459,7 @@ struct GridwiseReduction_mk_to_m_multiblock ...@@ -459,7 +459,7 @@ struct GridwiseReduction_mk_to_m_multiblock
in_thread_idx_buf); in_thread_idx_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
AccDataType tmpValue = zeroVal; AccDataType tmpValue = identityVal;
IndexDataType tmpIndex = 0; IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
...@@ -512,7 +512,7 @@ struct GridwiseReduction_mk_to_m_multiblock ...@@ -512,7 +512,7 @@ struct GridwiseReduction_mk_to_m_multiblock
in_thread_val_buf(Number<offset>{})); in_thread_val_buf(Number<offset>{}));
}); });
AccDataType tmpValue = zeroVal; AccDataType tmpValue = identityVal;
IndexDataType tmpIndex = 0; IndexDataType tmpIndex = 0;
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
......
...@@ -135,12 +135,12 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -135,12 +135,12 @@ struct GridwiseReduction_mk_to_m_threadwise
ReduceOperation, ReduceOperation,
PropagateNan>; PropagateNan>;
const auto zeroVal = ReduceOperation::GetReductionZeroVal(); const auto identityVal = ReduceOperation::GetIdentityValue();
const auto in_global_val_buf = const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global, make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(), in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal)); type_convert<InDataType>(identityVal));
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_value_global, out_grid_desc_m.GetElementSpaceSize()); p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
...@@ -149,7 +149,7 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -149,7 +149,7 @@ struct GridwiseReduction_mk_to_m_threadwise
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf; StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; }); static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
...@@ -276,12 +276,12 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -276,12 +276,12 @@ struct GridwiseReduction_mk_to_m_threadwise
(void)acc_elementwise_op; (void)acc_elementwise_op;
const auto zeroVal = ReduceOperation::GetReductionZeroVal(); const auto identityVal = ReduceOperation::GetIdentityValue();
const auto in_global_val_buf = const auto in_global_val_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global, make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(), in_grid_desc_m_k.GetElementSpaceSize(),
type_convert<InDataType>(zeroVal)); type_convert<InDataType>(identityVal));
const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize()); p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
...@@ -303,7 +303,7 @@ struct GridwiseReduction_mk_to_m_threadwise ...@@ -303,7 +303,7 @@ struct GridwiseReduction_mk_to_m_threadwise
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf; StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
accu_value_buf(I) = zeroVal; accu_value_buf(I) = identityVal;
accu_index_buf(I) = 0; accu_index_buf(I) = 0;
}); });
......
...@@ -816,10 +816,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -816,10 +816,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
false>; false>;
// Global write Gemm shuffle + reduction // Global write Gemm shuffle + reduction
const auto d_zeroVal = DReduceOperation::GetReductionZeroVal(); const auto d_identityVal = DReduceOperation::GetIdentityValue();
static_for<0, mreduce_per_thread, 1>{}( static_for<0, mreduce_per_thread, 1>{}(
[&](auto I) { d_thread_buf(I) = d_zeroVal; }); [&](auto I) { d_thread_buf(I) = d_identityVal; });
// reduce in VGPR // reduce in VGPR
static_for<0, mreduce_per_thread, 1>{}([&](auto im) { static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
......
...@@ -3,11 +3,13 @@ ...@@ -3,11 +3,13 @@
#include <cmath> #include <cmath>
#include "data_type.hpp" #include "data_type.hpp"
#include "half.hpp" #include "type.hpp"
namespace ck { namespace ck {
namespace math { namespace math {
// math functions for the host, some are implemented by calling C++ std functions
static inline __host__ float abs(float x) { return std::abs(x); }; static inline __host__ float abs(float x) { return std::abs(x); };
static inline __host__ double abs(double x) { return std::abs(x); }; static inline __host__ double abs(double x) { return std::abs(x); };
...@@ -28,26 +30,26 @@ static inline __host__ int32_t abs(int32_t x) ...@@ -28,26 +30,26 @@ static inline __host__ int32_t abs(int32_t x)
static inline __host__ half_t abs(half_t x) static inline __host__ half_t abs(half_t x)
{ {
half_float::half xx = *reinterpret_cast<half_float::half*>(&x); uint16_t xx = ck::bit_cast<uint16_t>(x);
half_float::half abs_xx = half_float::abs(xx); uint16_t abs_xx = xx & 0x7fff;
half_t abs_x = *reinterpret_cast<half_t*>(&abs_xx); half_t abs_x = ck::bit_cast<half_t>(abs_xx);
return abs_x; return abs_x;
}; };
static inline __host__ float isnan(float x) { return std::isnan(x); }; static inline __host__ bool isnan(float x) { return std::isnan(x); };
static inline __host__ double isnan(double x) { return std::isnan(x); }; static inline __host__ bool isnan(double x) { return std::isnan(x); };
static inline __host__ int8_t isnan(int8_t x) static inline __host__ bool isnan(int8_t x)
{ {
(void)x; (void)x;
return false; return false;
}; };
static inline __host__ int32_t isnan(int32_t x) static inline __host__ bool isnan(int32_t x)
{ {
(void)x; (void)x;
return false; return false;
...@@ -55,11 +57,59 @@ static inline __host__ int32_t isnan(int32_t x) ...@@ -55,11 +57,59 @@ static inline __host__ int32_t isnan(int32_t x)
static inline __host__ bool isnan(half_t x) static inline __host__ bool isnan(half_t x)
{ {
half_float::half xx = *reinterpret_cast<half_float::half*>(&x); uint16_t xx = ck::bit_cast<uint16_t>(x);
return (xx & 0x7FFF) > 0x7C00;
};
static inline __host__ float sqrt(float x) { return std::sqrt(x); };
static inline __host__ double sqrt(double x) { return std::sqrt(x); };
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
static inline __device__ float abs(float x) { return ::abs(x); };
static inline __device__ double abs(double x) { return ::abs(x); };
static inline __device__ int8_t abs(int8_t x)
{
int8_t sgn = x >> (8 - 1);
return (x ^ sgn) - sgn;
};
static inline __device__ int32_t abs(int32_t x)
{
int32_t sgn = x >> (32 - 1);
return (x ^ sgn) - sgn;
};
static inline __device__ half_t abs(half_t x) { return ::__habs(x); };
static inline __device__ bool isnan(float x) { return ::isnan(x); };
static inline __device__ bool isnan(double x) { return ::isnan(x); };
static inline __device__ bool isnan(int8_t x)
{
(void)x;
return false;
};
return half_float::isnan(xx); static inline __device__ bool isnan(int32_t x)
{
(void)x;
return false;
}; };
static inline __device__ bool isnan(half_t x) { return ::__hisnan(x); };
static inline __device__ float sqrt(float x) { return ::sqrtf(x); };
static inline __device__ double sqrt(double x) { return ::sqrt(x); };
} // namespace math } // namespace math
} // namespace ck } // namespace ck
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#define CK_REDUCTION_FUNCTIONS_BINOP_HPP #define CK_REDUCTION_FUNCTIONS_BINOP_HPP
#include "data_type.hpp" #include "data_type.hpp"
#include "math_v2.hpp"
#include "reduction_common.hpp" #include "reduction_common.hpp"
#include "reduction_operator.hpp" #include "reduction_operator.hpp"
...@@ -34,18 +35,6 @@ ...@@ -34,18 +35,6 @@
namespace ck { namespace ck {
namespace detail { namespace detail {
template <typename T>
static inline __device__ bool is_nan(T x)
{
return (isnan(x));
};
template <>
inline __device__ bool is_nan<half_t>(half_t x)
{
return (__hisnan(x));
};
template <bool PropagateNan, typename ReduceOperation, typename AccDataType> template <bool PropagateNan, typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheck; struct AccumulateWithNanCheck;
...@@ -53,7 +42,7 @@ template <typename ReduceOperation, typename AccDataType> ...@@ -53,7 +42,7 @@ template <typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType> struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
{ {
// cppcheck-suppress constParameter // cppcheck-suppress constParameter
__device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal) __host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
{ {
ReduceOperation{}(accuVal, currVal); ReduceOperation{}(accuVal, currVal);
}; };
...@@ -62,9 +51,11 @@ struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType> ...@@ -62,9 +51,11 @@ struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
template <typename ReduceOperation, typename AccDataType> template <typename ReduceOperation, typename AccDataType>
struct AccumulateWithNanCheck<true, ReduceOperation, AccDataType> struct AccumulateWithNanCheck<true, ReduceOperation, AccDataType>
{ {
__device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal) __host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
{ {
if(is_nan(currVal)) using ck::math::isnan;
if(isnan(currVal))
{ {
accuVal = currVal; accuVal = currVal;
} }
...@@ -81,7 +72,7 @@ struct AccumulateWithIndexAndNanCheck; ...@@ -81,7 +72,7 @@ struct AccumulateWithIndexAndNanCheck;
template <typename ReduceOperation, typename AccDataType, typename IndexDataType> template <typename ReduceOperation, typename AccDataType, typename IndexDataType>
struct AccumulateWithIndexAndNanCheck<false, ReduceOperation, AccDataType, IndexDataType> struct AccumulateWithIndexAndNanCheck<false, ReduceOperation, AccDataType, IndexDataType>
{ {
__device__ static inline void __host__ __device__ static inline void
// cppcheck-suppress constParameter // cppcheck-suppress constParameter
Calculate(AccDataType& accuVal, Calculate(AccDataType& accuVal,
AccDataType currVal, AccDataType currVal,
...@@ -101,12 +92,14 @@ template <typename ReduceOperation, typename AccDataType, typename IndexDataType ...@@ -101,12 +92,14 @@ template <typename ReduceOperation, typename AccDataType, typename IndexDataType
struct AccumulateWithIndexAndNanCheck<true, ReduceOperation, AccDataType, IndexDataType> struct AccumulateWithIndexAndNanCheck<true, ReduceOperation, AccDataType, IndexDataType>
{ {
// The method is called when the ReduceOperation is indexable and the user asked for indices // The method is called when the ReduceOperation is indexable and the user asked for indices
__device__ static inline void Calculate(AccDataType& accuVal, __host__ __device__ static inline void Calculate(AccDataType& accuVal,
AccDataType currVal, AccDataType currVal,
IndexDataType& accuIndex, IndexDataType& accuIndex,
IndexDataType currIndex) IndexDataType currIndex)
{ {
if(is_nan(currVal)) using ck::math::isnan;
if(isnan(currVal))
{ {
accuVal = currVal; accuVal = currVal;
accuIndex = currIndex; accuIndex = currIndex;
......
...@@ -36,7 +36,7 @@ namespace reduce { ...@@ -36,7 +36,7 @@ namespace reduce {
// Every binary operator used in reduction is represented by a templated functor class. Each functor // Every binary operator used in reduction is represented by a templated functor class. Each functor
// class must provide at least // class must provide at least
// three members: // three members:
// 1) GetReductionZeroVal() -- the interface to return the "identity element" for the binary // 1) GetIdentityValue() -- the interface to return the "identity element" for the binary
// operator, "identity element" is the unique // operator, "identity element" is the unique
// element in the algebraic space that doesn't affect the value of other elements // element in the algebraic space that doesn't affect the value of other elements
// when operated against them, and the concept is similar to zero vector in // when operated against them, and the concept is similar to zero vector in
...@@ -59,7 +59,7 @@ struct Add ...@@ -59,7 +59,7 @@ struct Add
{ {
using dataType = T; using dataType = T;
__host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); }; __host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); };
__device__ static constexpr bool __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
...@@ -76,7 +76,7 @@ struct Mul ...@@ -76,7 +76,7 @@ struct Mul
{ {
using dataType = T; using dataType = T;
__host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(1.0f); }; __host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(1.0f); };
__device__ static constexpr bool __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
...@@ -92,7 +92,7 @@ struct Max ...@@ -92,7 +92,7 @@ struct Max
{ {
using dataType = T; using dataType = T;
__host__ __device__ static constexpr T GetReductionZeroVal() __host__ __device__ static constexpr T GetIdentityValue()
{ {
return NumericLimits<T>::Lowest(); return NumericLimits<T>::Lowest();
}; };
...@@ -125,10 +125,7 @@ struct Min ...@@ -125,10 +125,7 @@ struct Min
{ {
using dataType = T; using dataType = T;
__host__ __device__ static constexpr T GetReductionZeroVal() __host__ __device__ static constexpr T GetIdentityValue() { return NumericLimits<T>::Max(); };
{
return NumericLimits<T>::Max();
};
__device__ static constexpr bool __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
...@@ -158,7 +155,7 @@ struct AMax ...@@ -158,7 +155,7 @@ struct AMax
{ {
using dataType = T; using dataType = T;
__host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); }; __host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); };
__device__ static constexpr bool __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
...@@ -184,7 +181,7 @@ struct AMax ...@@ -184,7 +181,7 @@ struct AMax
}; };
template <typename T> template <typename T>
T GetReductionZeroValueForInMemoryDataOperation(InMemoryDataOperationEnum operation) T GetIdentityValueueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
{ {
T result = ck::type_convert<T>(0.0f); T result = ck::type_convert<T>(0.0f);
......
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef GUARD_HOST_REDUCE_UTIL_HPP
#define GUARD_HOST_REDUCE_UTIL_HPP
#include <limits>
#include <cmath>
#include <functional>
#include "reduction_enums.hpp"
#include "data_type.hpp"
#include "math_v2.hpp"
namespace ck {
namespace host_reduce {
using ck::NanPropagation;
using ck::ReduceTensorOp;
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&)> PreUnaryOpFn(int)
{
using ck::math::abs;
if constexpr(ReduceOpId == ReduceTensorOp::NORM1)
{
return ([&](AccDataType& a_) { a_ = abs(a_); });
}
else if constexpr(ReduceOpId == ReduceTensorOp::NORM2)
{
return ([&](AccDataType& a_) { a_ = a_ * a_; });
}
else if constexpr(ReduceOpId == ReduceTensorOp::AMAX)
{
return ([&](AccDataType& a_) { a_ = abs(a_); });
}
else
{
// ReduceTensorOp::AVG:
// ReduceTensorOp::ADD:
// ReduceTensorOp::MUL:
// ReduceTensorOp::MIN:
// ReduceTensorOp::MAX:
return ([&](AccDataType&) {});
};
};
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&)> PosUnaryOpFn(int32_t divider)
{
using std::sqrt;
if constexpr(ReduceOpId == ReduceTensorOp::NORM2)
{
return ([&](AccDataType& a_) { a_ = sqrt(a_); });
}
else if constexpr(ReduceOpId == ReduceTensorOp::AVG)
{
return ([&, divider](AccDataType& a_) {
a_ = a_ / static_cast<AccDataType>(static_cast<float>(divider));
});
}
else
{
// ReduceTensorOp::ADD:
// ReduceTensorOp::NORM1:
// ReduceTensorOp::MUL:
// ReduceTensorOp::MIN:
// ReduceTensorOp::MAX:
// ReduceTensorOp::AMAX:
return ([&](AccDataType&) {});
}
};
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&, AccDataType)> ReduceOpFn()
{
if constexpr(ReduceOpId == ReduceTensorOp::ADD || ReduceOpId == ReduceTensorOp::AVG ||
ReduceOpId == ReduceTensorOp::NORM1 || ReduceOpId == ReduceTensorOp::NORM2)
{
return ([&](AccDataType& a_, AccDataType b_) { a_ = a_ + b_; });
}
else if constexpr(ReduceOpId == ReduceTensorOp::MUL)
{
return ([&](AccDataType& a_, AccDataType b_) { a_ = a_ * b_; });
}
else if constexpr(ReduceOpId == ReduceTensorOp::MIN)
{
return ([&](AccDataType& a_, AccDataType b_) {
if(a_ > b_)
a_ = b_;
});
}
else if constexpr(ReduceOpId == ReduceTensorOp::MAX || ReduceOpId == ReduceTensorOp::AMAX)
{
return ([&](AccDataType& a_, AccDataType b_) {
if(a_ < b_)
a_ = b_;
});
}
};
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&, AccDataType, bool& changed)> ReduceOpFn2()
{
if constexpr(ReduceOpId == ReduceTensorOp::MIN)
{
return ([&](AccDataType& a_, AccDataType b_, bool& changed) {
if(a_ > b_)
{
a_ = b_;
changed = true;
}
else
changed = false;
});
}
else if constexpr(ReduceOpId == ReduceTensorOp::MAX || ReduceOpId == ReduceTensorOp::AMAX)
{
return ([&](AccDataType& a_, AccDataType b_, bool& changed) {
if(a_ < b_)
{
a_ = b_;
changed = true;
}
else
changed = false;
});
}
else
{
// ReduceTensorOp::ADD:
// ReduceTensorOp::MUL:
// ReduceTensorOp::AVG:
// ReduceTensorOp::NORM1:
// ReduceTensorOp::NORM2:
return (std::function<void(AccDataType&, AccDataType, bool&)>{});
};
};
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline AccDataType ReduceOpZeroVal()
{
if constexpr(ReduceOpId == ReduceTensorOp::MUL)
{
return (static_cast<AccDataType>(1.0f));
}
else if constexpr(ReduceOpId == ReduceTensorOp::MIN)
{
return (ck::NumericLimits<AccDataType>::Max());
}
else if constexpr(ReduceOpId == ReduceTensorOp::MAX)
{
return (ck::NumericLimits<AccDataType>::Lowest());
}
else if constexpr(ReduceOpId == ReduceTensorOp::AMAX)
{
return (static_cast<AccDataType>(0.0f));
}
else
{
// ReduceTensorOp::ADD
// ReduceTensorOp::AVG
// ReduceTensorOp::NORM1
// ReduceTensorOp::NORM2
return (static_cast<AccDataType>(0.0f));
};
};
template <typename AccDataType, bool PropagateNan>
__host__ static inline void
binop_with_nan_check(std::function<void(AccDataType&, AccDataType)> opReduce,
AccDataType& accuVal,
AccDataType currVal)
{
using ck::math::isnan;
if constexpr(!PropagateNan)
{
opReduce(accuVal, currVal);
}
else
{
if(isnan(currVal))
accuVal = currVal;
else
opReduce(accuVal, currVal);
};
};
template <typename AccDataType, typename IndexDataType, bool PropagateNan>
__host__ static inline void
binop_with_index_and_nan_check(std::function<void(AccDataType&, AccDataType, bool&)> opReduce,
AccDataType& accuVal,
AccDataType currVal,
IndexDataType& accuIndex,
IndexDataType currIndex)
{
using ck::math::isnan;
if constexpr(!PropagateNan)
{
bool changed;
opReduce(accuVal, currVal, changed);
if(changed)
accuIndex = currIndex;
}
else
{
if(isnan(currVal))
{
accuVal = currVal;
accuIndex = currIndex;
}
else
{
bool changed;
opReduce(accuVal, currVal, changed);
if(changed)
accuIndex = currIndex;
};
};
};
}; // namespace host_reduce
}; // namespace ck
#endif
...@@ -33,10 +33,10 @@ ...@@ -33,10 +33,10 @@
#include "reduction_enums.hpp" #include "reduction_enums.hpp"
#include "reduction_common.hpp" #include "reduction_common.hpp"
#include "host_reduce_util.hpp"
#include "host_common_util.hpp" #include "host_common_util.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "data_type.hpp" #include "data_type.hpp"
#include "reduction_functions_accumulate.hpp"
template <int NDim> template <int NDim>
static void get_all_indexes(const std::array<size_t, NDim>& dimLengths, static void get_all_indexes(const std::array<size_t, NDim>& dimLengths,
...@@ -106,11 +106,13 @@ static size_t get_offset_from_index(const std::vector<size_t>& strides, ...@@ -106,11 +106,13 @@ static size_t get_offset_from_index(const std::vector<size_t>& strides,
template <typename InDataType, template <typename InDataType,
typename AccDataType, typename AccDataType,
typename OutDataType, typename OutDataType,
ck::ReduceTensorOp ReduceOpId, typename ReduceOperation,
typename InElementwiseOperation,
typename AccElementwiseOperation,
int Rank, int Rank,
int NumReduceDim, int NumReduceDim,
bool PropagateNan, bool PropagateNan,
bool NeedIndices> bool OutputIndex>
struct ReductionHost struct ReductionHost
{ {
using IndexDataType = int32_t; using IndexDataType = int32_t;
...@@ -122,8 +124,6 @@ struct ReductionHost ...@@ -122,8 +124,6 @@ struct ReductionHost
std::vector<int> reduceDims; std::vector<int> reduceDims;
IndexDataType divider; IndexDataType divider;
std::function<void(AccDataType&)> preUnaryOp;
std::function<void(AccDataType&)> posUnaryOp;
std::array<size_t, NumReduceDim> reduceLengths; std::array<size_t, NumReduceDim> reduceLengths;
std::array<size_t, NumReduceDim> reduceStrides; std::array<size_t, NumReduceDim> reduceStrides;
std::array<size_t, NumInvariantDim> invariantLengths; std::array<size_t, NumInvariantDim> invariantLengths;
...@@ -137,9 +137,6 @@ struct ReductionHost ...@@ -137,9 +137,6 @@ struct ReductionHost
const std::vector<int>& invariantDims_, const std::vector<int>& invariantDims_,
const std::vector<int>& reduceDims_) const std::vector<int>& reduceDims_)
{ {
using ck::host_reduce::PosUnaryOpFn;
using ck::host_reduce::PreUnaryOpFn;
// this->outLengths = to_int_vector(outDesc.GetLengths()); // this->outLengths = to_int_vector(outDesc.GetLengths());
this->outStrides = outDesc.GetStrides(); this->outStrides = outDesc.GetStrides();
...@@ -171,9 +168,6 @@ struct ReductionHost ...@@ -171,9 +168,6 @@ struct ReductionHost
invariant_dim_indexes.clear(); invariant_dim_indexes.clear();
get_all_indexes<NumInvariantDim>(invariantLengths, invariant_dim_indexes); get_all_indexes<NumInvariantDim>(invariantLengths, invariant_dim_indexes);
}; };
preUnaryOp = PreUnaryOpFn<AccDataType, ReduceOpId>(divider);
posUnaryOp = PosUnaryOpFn<AccDataType, ReduceOpId>(divider);
}; };
void Run(float alpha, void Run(float alpha,
...@@ -182,7 +176,7 @@ struct ReductionHost ...@@ -182,7 +176,7 @@ struct ReductionHost
OutDataType* out_data, OutDataType* out_data,
IndexDataType* out_indices) IndexDataType* out_indices)
{ {
if constexpr(NeedIndices) if constexpr(OutputIndex)
{ {
RunImpl_with_index(alpha, in_data, beta, out_data, out_indices); RunImpl_with_index(alpha, in_data, beta, out_data, out_indices);
} }
...@@ -201,15 +195,17 @@ struct ReductionHost ...@@ -201,15 +195,17 @@ struct ReductionHost
using ck::float_equal_one; using ck::float_equal_one;
using ck::float_equal_zero; using ck::float_equal_zero;
using ck::type_convert; using ck::type_convert;
using ck::host_reduce::binop_with_index_and_nan_check;
using ck::host_reduce::ReduceOpFn2;
using ck::host_reduce::ReduceOpZeroVal;
auto opReduce2 = ReduceOpFn2<AccDataType, ReduceOpId>(); using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
AccDataType,
IndexDataType>;
InElementwiseOperation in_elementwise_op(divider);
AccElementwiseOperation acc_elementwise_op(divider);
if constexpr(NumInvariantDim == 0) if constexpr(NumInvariantDim == 0)
{ {
AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>(); AccDataType accuVal = ReduceOperation::GetIdentityValue();
IndexDataType accuIndex = 0; IndexDataType accuIndex = 0;
for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++) for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++)
...@@ -219,15 +215,14 @@ struct ReductionHost ...@@ -219,15 +215,14 @@ struct ReductionHost
auto currVal = type_convert<AccDataType>(in_data[offset_reduce]); auto currVal = type_convert<AccDataType>(in_data[offset_reduce]);
preUnaryOp(currVal); in_elementwise_op(currVal, currVal);
auto currIndex = static_cast<IndexDataType>(i); auto currIndex = static_cast<IndexDataType>(i);
binop_with_index_and_nan_check<AccDataType, IndexDataType, PropagateNan>( Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
opReduce2, accuVal, currVal, accuIndex, currIndex);
}; };
posUnaryOp(accuVal); acc_elementwise_op(accuVal, accuVal);
if(!float_equal_one{}(alpha)) if(!float_equal_one{}(alpha))
accuVal *= type_convert<AccDataType>(alpha); accuVal *= type_convert<AccDataType>(alpha);
...@@ -241,7 +236,7 @@ struct ReductionHost ...@@ -241,7 +236,7 @@ struct ReductionHost
else else
{ {
auto thread_reduce_func = [&](auto invariant_index) { auto thread_reduce_func = [&](auto invariant_index) {
AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>(); AccDataType accuVal = ReduceOperation::GetIdentityValue();
IndexDataType accuIndex = 0; IndexDataType accuIndex = 0;
auto offset_invariant = auto offset_invariant =
...@@ -255,15 +250,14 @@ struct ReductionHost ...@@ -255,15 +250,14 @@ struct ReductionHost
auto currVal = auto currVal =
type_convert<AccDataType>(in_data[offset_invariant + offset_reduce]); type_convert<AccDataType>(in_data[offset_invariant + offset_reduce]);
preUnaryOp(currVal); in_elementwise_op(currVal, currVal);
auto currIndex = static_cast<IndexDataType>(i); auto currIndex = static_cast<IndexDataType>(i);
binop_with_index_and_nan_check<AccDataType, IndexDataType, PropagateNan>( Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
opReduce2, accuVal, currVal, accuIndex, currIndex);
}; };
posUnaryOp(accuVal); acc_elementwise_op(accuVal, accuVal);
if(!float_equal_one{}(alpha)) if(!float_equal_one{}(alpha))
accuVal *= type_convert<AccDataType>(alpha); accuVal *= type_convert<AccDataType>(alpha);
...@@ -308,15 +302,16 @@ struct ReductionHost ...@@ -308,15 +302,16 @@ struct ReductionHost
using ck::float_equal_one; using ck::float_equal_one;
using ck::float_equal_zero; using ck::float_equal_zero;
using ck::type_convert; using ck::type_convert;
using ck::host_reduce::binop_with_nan_check;
using ck::host_reduce::ReduceOpFn;
using ck::host_reduce::ReduceOpZeroVal;
auto opReduce = ReduceOpFn<AccDataType, ReduceOpId>(); using Accumulation =
ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
InElementwiseOperation in_elementwise_op(divider);
AccElementwiseOperation acc_elementwise_op(divider);
if constexpr(NumInvariantDim == 0) if constexpr(NumInvariantDim == 0)
{ {
AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>(); AccDataType accuVal = ReduceOperation::GetIdentityValue();
for(const auto& reduce_index : reduce_dim_indexes) for(const auto& reduce_index : reduce_dim_indexes)
{ {
...@@ -325,12 +320,12 @@ struct ReductionHost ...@@ -325,12 +320,12 @@ struct ReductionHost
auto currVal = type_convert<AccDataType>(in_data[offset_reduce]); auto currVal = type_convert<AccDataType>(in_data[offset_reduce]);
preUnaryOp(currVal); in_elementwise_op(currVal, currVal);
binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal); Accumulation::Calculate(accuVal, currVal);
}; };
posUnaryOp(accuVal); acc_elementwise_op(accuVal, accuVal);
if(!float_equal_one{}(alpha)) if(!float_equal_one{}(alpha))
accuVal *= type_convert<AccDataType>(alpha); accuVal *= type_convert<AccDataType>(alpha);
...@@ -343,7 +338,7 @@ struct ReductionHost ...@@ -343,7 +338,7 @@ struct ReductionHost
else else
{ {
auto thread_reduce_func = [&](auto invariant_index) { auto thread_reduce_func = [&](auto invariant_index) {
AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>(); AccDataType accuVal = ReduceOperation::GetIdentityValue();
auto offset_invariant = auto offset_invariant =
get_offset_from_index<NumInvariantDim>(invariantStrides, invariant_index); get_offset_from_index<NumInvariantDim>(invariantStrides, invariant_index);
...@@ -356,12 +351,12 @@ struct ReductionHost ...@@ -356,12 +351,12 @@ struct ReductionHost
auto currVal = auto currVal =
type_convert<AccDataType>(in_data[offset_invariant + offset_reduce]); type_convert<AccDataType>(in_data[offset_invariant + offset_reduce]);
preUnaryOp(currVal); in_elementwise_op(currVal, currVal);
binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal); Accumulation::Calculate(accuVal, currVal);
}; };
posUnaryOp(accuVal); acc_elementwise_op(accuVal, accuVal);
if(!float_equal_one{}(alpha)) if(!float_equal_one{}(alpha))
accuVal *= type_convert<AccDataType>(alpha); accuVal *= type_convert<AccDataType>(alpha);
......
...@@ -171,8 +171,8 @@ bool profile_batched_gemm_reduce_impl(int do_verification, ...@@ -171,8 +171,8 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
{ {
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
float d0_acc = d0_reduce_op.GetReductionZeroVal(); float d0_acc = d0_reduce_op.GetIdentityValue();
float d1_acc = d1_reduce_op.GetReductionZeroVal(); float d1_acc = d1_reduce_op.GetIdentityValue();
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment