"driver/src/conv_driver.cpp" did not exist on "7a7fe160866b7b2893be698d77b70cc8cf754fb5"
Unverified commit 86185bd7 authored by Qianfeng, committed by GitHub

Unify the naming of the math functions used by the host and kernel (#262)

* Use the unified naming for math functions on host and HIP kernel

* Corresponding change/simplification in reduction host/profiler/examples due to unified math functions renaming

* Renaming GetReductionZeroVal() to GetIdentityValue()

* Tiny renaming in profile_reduce_impl.hpp

* More renaming in profile_reduce_impl.hpp

* Replace zeroVal by identityVal

* Remove ck_ prefix in the naming of ck::math provided functions
parent b6eaf3eb
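The new name says what the value actually is: the identity element of the reduction's binary operator, which is not always zero (0 for Add, 1 for Mul, the type's lowest value for Max, the type's maximum for Min, per the reduction_operator.hpp hunks below). A minimal standalone sketch of the concept, with illustrative names rather than CK's:

```cpp
#include <limits>

// Illustrative only: each reduce op pairs its combine step with the identity
// element of that operation, mirroring the renamed GetIdentityValue().
template <typename T>
struct AddOp
{
    static constexpr T GetIdentityValue() { return static_cast<T>(0); } // x + 0 == x
    void operator()(T& a, T b) const { a = a + b; }
};

template <typename T>
struct MaxOp
{
    // max(x, lowest) == x, so "lowest" is the identity, not zero.
    static constexpr T GetIdentityValue() { return std::numeric_limits<T>::lowest(); }
    void operator()(T& a, T b) const { if(a < b) a = b; }
};
```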
@@ -147,8 +147,6 @@ class SimpleAppArgs
 int main(int argc, char* argv[])
 {
-    using namespace ck::host_reduce;
-
     const std::vector<int> reduceDims{0, 1, 2};
     const std::vector<int> invariantDims{3};
@@ -254,7 +252,9 @@ int main(int argc, char* argv[])
     ReductionHost<InDataType,
                   AccDataType,
                   OutDataType,
-                  ReduceOpId,
+                  ReduceOperation,
+                  InElementwiseOperation,
+                  AccElementwiseOperation,
                   Rank,
                   NumReduceDim,
                   PropagateNan,
......
@@ -108,8 +108,6 @@ int main(int argc, char* argv[])
     const std::vector<size_t> outLengths = {64, 320, 80};

-    using namespace ck::host_reduce;
-
     if(argc == 1)
     {
         do_verify = true;
@@ -191,7 +189,9 @@ int main(int argc, char* argv[])
     ReductionHost<InOutDataType,
                   AccDataType,
                   InOutDataType,
-                  ReduceOpId,
+                  ReduceOperation,
+                  InElementwiseOperation,
+                  AccElementwiseOperation,
                   5, // Rank
                   2, // NumReduceDim
                   PropagateNan,
......
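With the ReduceOpId template parameter replaced by explicit operation types, a caller derives those types from the op id via the trait classes in reduction_operator_mapping.hpp, as the pooling example below does. A hedged sketch of the instantiation pattern; the constructor arguments (descriptors and dimension lists) are assumed from the surrounding example code, which this diff only shows in part:

```cpp
// Assumed context: InOutDataType/AccDataType, PropagateNan, OutputIndex and
// ReduceOpId are defined as in the examples above.
using ReduceOperation = typename ck::reduce_binary_operator<AccDataType, ReduceOpId>::opType;
using InElementwiseOperation = typename ck::
    reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
using AccElementwiseOperation = typename ck::
    reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation;

ReductionHost<InOutDataType,
              AccDataType,
              InOutDataType,
              ReduceOperation,
              InElementwiseOperation,
              AccElementwiseOperation,
              5, // Rank
              2, // NumReduceDim
              PropagateNan,
              OutputIndex>
    hostReduce(inDesc, outDesc, invariantDims, reduceDims); // ctor args assumed
```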
@@ -8,10 +8,12 @@
 #include "device.hpp"
 #include "host_tensor.hpp"
 #include "host_tensor_generator.hpp"
-#include "host_reduce_util.hpp"
 #include "device_tensor.hpp"
 #include "tensor_layout.hpp"
 #include "reduction_enums.hpp"
+#include "reduction_operator_mapping.hpp"
+#include "reduction_functions_accumulate.hpp"
 #include "device_pool2d_fwd_nhwc_nhwc.hpp"

 template <typename InDataType,
@@ -29,19 +31,24 @@ static void pool_host_verify(const Tensor<InDataType>& in,
                              const std::array<ck::index_t, 2>& in_left_pads,
                              const std::array<ck::index_t, 2>& /*in_right_pads*/)
 {
-    using namespace ck::host_reduce;
-
     const int32_t divider = window_spatial_lengths[0] * window_spatial_lengths[1];

-    const auto PreUnaryOp = PreUnaryOpFn<AccDataType, ReduceOpId>(divider);
-    const auto PosUnaryOp = PosUnaryOpFn<AccDataType, ReduceOpId>(divider);
+    using ReduceOperation = typename ck::reduce_binary_operator<AccDataType, ReduceOpId>::opType;
+    using InElementwiseOperation = typename ck::
+        reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
+    using AccElementwiseOperation = typename ck::
+        reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation;
+
+    const InElementwiseOperation in_elementwise_op(divider);
+    const AccElementwiseOperation acc_elementwise_op(divider);

     if constexpr(!OutputIndex)
     {
-        auto opReduce = ReduceOpFn<AccDataType, ReduceOpId>();
+        using Accumulation =
+            ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;

         auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
-            auto accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
+            auto accuVal = ReduceOperation::GetIdentityValue();

             for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y)
             {
@@ -54,14 +61,14 @@ static void pool_host_verify(const Tensor<InDataType>& in,
                     {
                         AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi));

-                        PreUnaryOp(currVal);
+                        in_elementwise_op(currVal, currVal);

-                        binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal);
+                        Accumulation::Calculate(accuVal, currVal);
                     }
                 }
             }

-            PosUnaryOp(accuVal);
+            acc_elementwise_op(accuVal, accuVal);

             out(n, c, ho, wo) = accuVal;
         };
@@ -74,10 +81,12 @@ static void pool_host_verify(const Tensor<InDataType>& in,
     }
     else
     {
-        auto opReduce = ReduceOpFn2<AccDataType, ReduceOpId>();
+        using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
+                                                                        ReduceOperation,
+                                                                        AccDataType,
+                                                                        IndexDataType>;

         auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
-            auto accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
+            auto accuVal = ReduceOperation::GetIdentityValue();
             IndexDataType accuIndex = 0;

             for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y)
@@ -92,15 +101,14 @@ static void pool_host_verify(const Tensor<InDataType>& in,
                         AccDataType currVal = static_cast<AccDataType>(in(n, c, hi, wi));
                         IndexDataType currIndex = y * window_spatial_lengths[1] + x;

-                        PreUnaryOp(currVal);
+                        in_elementwise_op(currVal, currVal);

-                        binop_with_index_and_nan_check<AccDataType, IndexDataType, PropagateNan>(
-                            opReduce, accuVal, currVal, accuIndex, currIndex);
+                        Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
                     }
                 }
             }

-            PosUnaryOp(accuVal);
+            acc_elementwise_op(accuVal, accuVal);

             out(n, c, ho, wo) = accuVal;
             out_indices(n, c, ho, wo) = accuIndex;
@@ -139,8 +147,6 @@ bool pool_test(bool do_verification,
               ck::index_t in_right_pad_h,
               ck::index_t in_right_pad_w)
 {
-    using namespace ck::host_reduce;
-
     using DevicePoolFwdInstance =
         ck::tensor_operation::device::DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<
             InDataType, // InDataType
......
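The deleted host helpers returned std::function objects; the host verification now constructs the same elementwise functor types the kernels use and applies them in place through the (output, input) convention, e.g. in_elementwise_op(v, v). A standalone sketch of that convention for an AVG-style pooling; the functor names here are illustrative stand-ins, not CK's definitions:

```cpp
#include <cstdint>

// For AVG pooling, the input-side op is a pass-through and the
// accumulation-side op divides by the window size, both in the (y, x) shape
// used by pool_host_verify above.
struct PassThrough
{
    explicit PassThrough(int32_t divider = 1) { (void)divider; }
    void operator()(float& y, const float& x) const { y = x; }
};

struct DivideByCount
{
    explicit DivideByCount(int32_t divider = 1) : divider_(divider) {}
    void operator()(float& y, const float& x) const { y = x / static_cast<float>(divider_); }
    int32_t divider_;
};

// Usage in the same shape as pool_host_verify's lambda:
//   in_elementwise_op(currVal, currVal);  // transform input in place
//   acc += currVal;                       // Accumulation::Calculate stand-in
//   acc_elementwise_op(acc, acc);         // finalize (divide for AVG)
```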
@@ -27,8 +27,6 @@ static constexpr bool PropagateNan = false;
 int main(int argc, char* argv[])
 {
-    using namespace ck::host_reduce;
-
     bool do_verification;
     int init_method;
     bool time_kernel;
......
@@ -27,8 +27,6 @@ static constexpr bool PropagateNan = false;
 int main(int argc, char* argv[])
 {
-    using namespace ck::host_reduce;
-
     bool do_verification;
     int init_method;
     bool time_kernel;
......
@@ -236,7 +236,7 @@ int main(int argc, char* argv[])
         for(int m = 0; m < M; ++m)
         {
-            ReduceAccDataType d_acc = d_reduce_op.GetReductionZeroVal();
+            ReduceAccDataType d_acc = d_reduce_op.GetIdentityValue();

             for(int n = 0; n < N; ++n)
                 d_reduce_op(d_acc, c_m_n_host_result(m, n));
......
@@ -261,8 +261,8 @@ int main(int argc, char* argv[])
         for(int m = 0; m < M; ++m)
         {
-            float d0_acc = d0_reduce_op.GetReductionZeroVal();
-            float d1_acc = d1_reduce_op.GetReductionZeroVal();
+            float d0_acc = d0_reduce_op.GetIdentityValue();
+            float d1_acc = d1_reduce_op.GetIdentityValue();

             for(int n = 0; n < N; ++n)
             {
......
@@ -259,8 +259,8 @@ int main(int argc, char* argv[])
     {
         for(int m = 0; m < M; ++m)
         {
-            float d0_acc = d0_reduce_op.GetReductionZeroVal();
-            float d1_acc = d1_reduce_op.GetReductionZeroVal();
+            float d0_acc = d0_reduce_op.GetIdentityValue();
+            float d1_acc = d1_reduce_op.GetIdentityValue();

             for(int n = 0; n < N; ++n)
             {
......
@@ -157,8 +157,8 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
     auto reduceSumOpInst = ReduceSumOp{};
     for(int m = 0; m < M; ++m)
     {
-        float mean_acc        = reduceSumOpInst.GetReductionZeroVal();
-        float square_mean_acc = reduceSumOpInst.GetReductionZeroVal();
+        float mean_acc        = reduceSumOpInst.GetIdentityValue();
+        float square_mean_acc = reduceSumOpInst.GetIdentityValue();

         for(int n = 0; n < N; ++n)
         {
......
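Seeding both accumulators with the Add identity is what keeps the mean/variance loop correct for any row length. A standalone restatement of the loop shape used in host_gemm_layernorm:

```cpp
#include <vector>

// Two accumulators, both initialized with the Add identity (0), then turned
// into mean and variance via E[x^2] - E[x]^2.
void row_mean_var(const std::vector<float>& row, float& mean, float& var)
{
    float mean_acc        = 0.0f; // ReduceSumOp identity
    float square_mean_acc = 0.0f;
    for(float v : row)
    {
        mean_acc += v;
        square_mean_acc += v * v;
    }
    const float n = static_cast<float>(row.size());
    mean = mean_acc / n;
    var  = square_mean_acc / n - mean * mean;
}
```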
@@ -348,8 +348,8 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
         if constexpr(use_multiblock)
         {
-            const auto zeroVal =
-                ck::reduce::GetReductionZeroValueForInMemoryDataOperation<OutDataType>(
+            const auto identityVal =
+                ck::reduce::GetIdentityValueForInMemoryDataOperation<OutDataType>(
                     OutMemoryDataOperation);

             const auto kernel_pre =
@@ -362,7 +362,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InElementwiseOperation, AccE
                     0,
                     out_grid_desc_m_2,
                     arg.out_dev_,
-                    zeroVal);
+                    identityVal);
             };

             avg_time += launch_and_time_kernel(stream_config,
......
 #pragma once
 #include "data_type.hpp"
+#include "math_v2.hpp"

 namespace ck {
 namespace tensor_operation {
@@ -296,7 +297,7 @@ struct UnaryAbs<float, float>
 {
     __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(float& y, const float& x) const { y = abs(x); };
+    __host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::abs(x); };
 };

 template <>
@@ -304,7 +305,7 @@ struct UnaryAbs<half_t, half_t>
 {
     __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = __habs(x); };
+    __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = ck::math::abs(x); };
 };

 template <>
@@ -312,7 +313,7 @@ struct UnaryAbs<double, double>
 {
     __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(double& y, const double& x) const { y = abs(x); };
+    __host__ __device__ void operator()(double& y, const double& x) const { y = ck::math::abs(x); };
 };

 template <>
@@ -320,12 +321,7 @@ struct UnaryAbs<int8_t, int8_t>
 {
     __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(int8_t& y, const int8_t& x) const
-    {
-        int8_t sgn = x >> (8 - 1);
-
-        y = (x ^ sgn) - sgn;
-    };
+    __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = ck::math::abs(x); };
 };

 template <typename Y, typename X>
@@ -336,7 +332,7 @@ struct UnarySqrt<float, float>
 {
     __host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(float& y, const float& x) const { y = sqrtf(x); };
+    __host__ __device__ void operator()(float& y, const float& x) const { y = ck::math::sqrt(x); };
 };

 template <>
@@ -344,7 +340,10 @@ struct UnarySqrt<double, double>
 {
     __host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; };

-    __host__ __device__ void operator()(double& y, const double& x) const { y = sqrt(x); };
+    __host__ __device__ void operator()(double& y, const double& x) const
+    {
+        y = ck::math::sqrt(x);
+    };
 };

 } // namespace element_wise
......
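Because ck::math now overloads abs and sqrt for both host and device, a unary functor needs only one body for both sides instead of per-type tricks like the removed int8 bit manipulation. A hypothetical functor in the same shape as the specializations above (not part of this commit):

```cpp
// Hypothetical: squared absolute value, written once for host and device by
// leaning on the unified ck::math overload set.
template <typename Y, typename X>
struct UnarySquareAbs
{
    __host__ __device__ UnarySquareAbs(const int32_t divider = 1) { (void)divider; };

    __host__ __device__ void operator()(Y& y, const X& x) const
    {
        y = ck::math::abs(x) * ck::math::abs(x);
    };
};
```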
@@ -171,7 +171,7 @@ struct GridwiseReduction_mk_to_m_multiblock
                                AccDataType beta,
                                OutDataType* const __restrict__ p_out_value_global)
     {
-        const auto zeroVal = ReduceOperation::GetReductionZeroVal();
+        const auto identityVal = ReduceOperation::GetIdentityValue();

         // LDS
         __shared__ AccDataType p_reduce_work_buffer[BlockSize];
@@ -179,7 +179,7 @@ struct GridwiseReduction_mk_to_m_multiblock
         const auto in_global_val_buf =
             make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
                                                           in_grid_desc_m_k.GetElementSpaceSize(),
-                                                          type_convert<InDataType>(zeroVal));
+                                                          type_convert<InDataType>(identityVal));
         auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
@@ -191,7 +191,7 @@ struct GridwiseReduction_mk_to_m_multiblock
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;

-        static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
+        static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });

         const index_t thread_local_id = get_thread_local_1d_id();
         const index_t block_global_id = get_block_1d_id();
@@ -358,12 +358,12 @@ struct GridwiseReduction_mk_to_m_multiblock
         __shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
         __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];

-        const auto zeroVal = ReduceOperation::GetReductionZeroVal();
+        const auto identityVal = ReduceOperation::GetIdentityValue();

         const auto in_global_val_buf =
             make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
                                                           in_grid_desc_m_k.GetElementSpaceSize(),
-                                                          type_convert<InDataType>(zeroVal));
+                                                          type_convert<InDataType>(identityVal));
         const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());

         auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -418,7 +418,7 @@ struct GridwiseReduction_mk_to_m_multiblock
                                        thread_k_cluster_id * KThreadSliceSize));

         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-            accu_value_buf(I) = zeroVal;
+            accu_value_buf(I) = identityVal;
             accu_index_buf(I) = 0;
         });
@@ -459,7 +459,7 @@ struct GridwiseReduction_mk_to_m_multiblock
                     in_thread_idx_buf);

                 static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-                    AccDataType tmpValue = zeroVal;
+                    AccDataType tmpValue = identityVal;
                     IndexDataType tmpIndex = 0;

                     static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
@@ -512,7 +512,7 @@ struct GridwiseReduction_mk_to_m_multiblock
                             in_thread_val_buf(Number<offset>{}));
                     });

-                    AccDataType tmpValue = zeroVal;
+                    AccDataType tmpValue = identityVal;
                     IndexDataType tmpIndex = 0;

                     static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
......
@@ -135,12 +135,12 @@ struct GridwiseReduction_mk_to_m_threadwise
                                                        ReduceOperation,
                                                        PropagateNan>;

-        const auto zeroVal = ReduceOperation::GetReductionZeroVal();
+        const auto identityVal = ReduceOperation::GetIdentityValue();

         const auto in_global_val_buf =
             make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
                                                           in_grid_desc_m_k.GetElementSpaceSize(),
-                                                          type_convert<InDataType>(zeroVal));
+                                                          type_convert<InDataType>(identityVal));
         auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
@@ -149,7 +149,7 @@ struct GridwiseReduction_mk_to_m_threadwise
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;

-        static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; });
+        static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });

         const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
@@ -276,12 +276,12 @@ struct GridwiseReduction_mk_to_m_threadwise
         (void)acc_elementwise_op;

-        const auto zeroVal = ReduceOperation::GetReductionZeroVal();
+        const auto identityVal = ReduceOperation::GetIdentityValue();

         const auto in_global_val_buf =
             make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
                                                           in_grid_desc_m_k.GetElementSpaceSize(),
-                                                          type_convert<InDataType>(zeroVal));
+                                                          type_convert<InDataType>(identityVal));
         const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
@@ -303,7 +303,7 @@ struct GridwiseReduction_mk_to_m_threadwise
         StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;

         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
-            accu_value_buf(I) = zeroVal;
+            accu_value_buf(I) = identityVal;
             accu_index_buf(I) = 0;
         });
......
@@ -816,10 +816,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                         false>;

                 // Global write Gemm shuffle + reduction
-                const auto d_zeroVal = DReduceOperation::GetReductionZeroVal();
+                const auto d_identityVal = DReduceOperation::GetIdentityValue();

                 static_for<0, mreduce_per_thread, 1>{}(
-                    [&](auto I) { d_thread_buf(I) = d_zeroVal; });
+                    [&](auto I) { d_thread_buf(I) = d_identityVal; });

                 // reduce in VGPR
                 static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
......
@@ -3,11 +3,13 @@
 #include <cmath>
 #include "data_type.hpp"
-#include "half.hpp"
+#include "type.hpp"

 namespace ck {
 namespace math {

+// math functions for the host, some are implemented by calling C++ std functions
+
 static inline __host__ float abs(float x) { return std::abs(x); };

 static inline __host__ double abs(double x) { return std::abs(x); };
@@ -28,26 +30,26 @@ static inline __host__ int32_t abs(int32_t x)
 static inline __host__ half_t abs(half_t x)
 {
-    half_float::half xx     = *reinterpret_cast<half_float::half*>(&x);
-    half_float::half abs_xx = half_float::abs(xx);
-    half_t abs_x            = *reinterpret_cast<half_t*>(&abs_xx);
+    uint16_t xx     = ck::bit_cast<uint16_t>(x);
+    uint16_t abs_xx = xx & 0x7fff;
+    half_t abs_x    = ck::bit_cast<half_t>(abs_xx);

     return abs_x;
 };

-static inline __host__ float isnan(float x) { return std::isnan(x); };
+static inline __host__ bool isnan(float x) { return std::isnan(x); };

-static inline __host__ double isnan(double x) { return std::isnan(x); };
+static inline __host__ bool isnan(double x) { return std::isnan(x); };

-static inline __host__ int8_t isnan(int8_t x)
+static inline __host__ bool isnan(int8_t x)
 {
     (void)x;
     return false;
 };

-static inline __host__ int32_t isnan(int32_t x)
+static inline __host__ bool isnan(int32_t x)
 {
     (void)x;
     return false;
@@ -55,11 +57,59 @@ static inline __host__ int32_t isnan(int32_t x)
 static inline __host__ bool isnan(half_t x)
 {
-    half_float::half xx = *reinterpret_cast<half_float::half*>(&x);
-    return half_float::isnan(xx);
+    uint16_t xx = ck::bit_cast<uint16_t>(x);
+    return (xx & 0x7FFF) > 0x7C00;
 };

+static inline __host__ float sqrt(float x) { return std::sqrt(x); };
+
+static inline __host__ double sqrt(double x) { return std::sqrt(x); };
+
+// math functions for the HIP kernel, some are implemented by calling hip builtin functions
+
+static inline __device__ float abs(float x) { return ::abs(x); };
+
+static inline __device__ double abs(double x) { return ::abs(x); };
+
+static inline __device__ int8_t abs(int8_t x)
+{
+    int8_t sgn = x >> (8 - 1);
+
+    return (x ^ sgn) - sgn;
+};
+
+static inline __device__ int32_t abs(int32_t x)
+{
+    int32_t sgn = x >> (32 - 1);
+
+    return (x ^ sgn) - sgn;
+};
+
+static inline __device__ half_t abs(half_t x) { return ::__habs(x); };
+
+static inline __device__ bool isnan(float x) { return ::isnan(x); };
+
+static inline __device__ bool isnan(double x) { return ::isnan(x); };
+
+static inline __device__ bool isnan(int8_t x)
+{
+    (void)x;
+    return false;
+};
+
+static inline __device__ bool isnan(int32_t x)
+{
+    (void)x;
+    return false;
+};
+
+static inline __device__ bool isnan(half_t x) { return ::__hisnan(x); };
+
+static inline __device__ float sqrt(float x) { return ::sqrtf(x); };
+
+static inline __device__ double sqrt(double x) { return ::sqrt(x); };
+
 } // namespace math
 } // namespace ck
......
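The half_t paths no longer round-trip through the half_float library; they work on the IEEE-754 binary16 bit pattern directly: bit 15 is the sign, and an all-ones exponent (0x7C00) with a nonzero mantissa means NaN. A standalone check of the two bit tricks, with plain uint16_t standing in for half_t:

```cpp
#include <cassert>
#include <cstdint>

// binary16 layout: 1 sign bit | 5 exponent bits | 10 mantissa bits.
uint16_t half_abs_bits(uint16_t x) { return x & 0x7FFF; }              // clear the sign bit
bool half_isnan_bits(uint16_t x) { return (x & 0x7FFF) > 0x7C00; }     // exp all ones, mantissa != 0

int main()
{
    assert(half_abs_bits(0xC000) == 0x4000); // abs(-2.0h) == 2.0h
    assert(half_isnan_bits(0x7E00));         // a quiet NaN
    assert(!half_isnan_bits(0x7C00));        // +inf is not NaN
    return 0;
}
```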
@@ -27,6 +27,7 @@
 #define CK_REDUCTION_FUNCTIONS_BINOP_HPP

 #include "data_type.hpp"
+#include "math_v2.hpp"
 #include "reduction_common.hpp"
 #include "reduction_operator.hpp"
@@ -34,18 +35,6 @@
 namespace ck {
 namespace detail {

-template <typename T>
-static inline __device__ bool is_nan(T x)
-{
-    return (isnan(x));
-};
-
-template <>
-inline __device__ bool is_nan<half_t>(half_t x)
-{
-    return (__hisnan(x));
-};
-
 template <bool PropagateNan, typename ReduceOperation, typename AccDataType>
 struct AccumulateWithNanCheck;
@@ -53,7 +42,7 @@ template <typename ReduceOperation, typename AccDataType>
 struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
 {
     // cppcheck-suppress constParameter
-    __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
+    __host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
     {
         ReduceOperation{}(accuVal, currVal);
     };
@@ -62,9 +51,11 @@ struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
 template <typename ReduceOperation, typename AccDataType>
 struct AccumulateWithNanCheck<true, ReduceOperation, AccDataType>
 {
-    __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
+    __host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal)
     {
-        if(is_nan(currVal))
+        using ck::math::isnan;
+
+        if(isnan(currVal))
         {
             accuVal = currVal;
         }
@@ -81,7 +72,7 @@ struct AccumulateWithIndexAndNanCheck;
 template <typename ReduceOperation, typename AccDataType, typename IndexDataType>
 struct AccumulateWithIndexAndNanCheck<false, ReduceOperation, AccDataType, IndexDataType>
 {
-    __device__ static inline void
+    __host__ __device__ static inline void
     // cppcheck-suppress constParameter
     Calculate(AccDataType& accuVal,
               AccDataType currVal,
@@ -101,12 +92,14 @@ template <typename ReduceOperation, typename AccDataType, typename IndexDataType
 struct AccumulateWithIndexAndNanCheck<true, ReduceOperation, AccDataType, IndexDataType>
 {
     // The method is called when the ReduceOperation is indexable and the user asked for indices
-    __device__ static inline void Calculate(AccDataType& accuVal,
-                                            AccDataType currVal,
-                                            IndexDataType& accuIndex,
-                                            IndexDataType currIndex)
+    __host__ __device__ static inline void Calculate(AccDataType& accuVal,
+                                                     AccDataType currVal,
+                                                     IndexDataType& accuIndex,
+                                                     IndexDataType currIndex)
     {
-        if(is_nan(currVal))
+        using ck::math::isnan;
+
+        if(isnan(currVal))
         {
             accuVal = currVal;
             accuIndex = currIndex;
......
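Marking Calculate as __host__ __device__ is what lets the host-side ReductionHost below reuse the same accumulator structs as the kernels. A standalone sketch of the NaN-check policy, with std::isnan standing in for ck::math::isnan:

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Same policy as AccumulateWithNanCheck<true, ...>: once a NaN is seen it
// sticks, because comparison-based reduce ops would otherwise discard it.
template <typename Op, typename T>
void accumulate_with_nan_check(T& accu, T curr)
{
    if(std::isnan(curr))
        accu = curr;
    else
        Op{}(accu, curr);
}

struct MaxOp
{
    void operator()(float& a, float b) const { if(a < b) a = b; }
};

int main()
{
    float acc = std::numeric_limits<float>::lowest(); // Max identity
    accumulate_with_nan_check<MaxOp>(acc, 1.0f);
    accumulate_with_nan_check<MaxOp>(acc, std::nanf(""));
    assert(std::isnan(acc)); // NaN propagated
    return 0;
}
```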
@@ -36,7 +36,7 @@ namespace reduce {

 // Every binary operator used in reduction is represented by a templated functor class. Each functor
 // class must provide at least
 // three members:
-// 1) GetReductionZeroVal() -- the interface to return the "identity element" for the binary
+// 1) GetIdentityValue() -- the interface to return the "identity element" for the binary
 //    operator, "identity element" is the unique
 //    element in the algebraic space that doesn't affect the value of other elements
 //    when operated against them, and the concept is similar to zero vector in
@@ -59,7 +59,7 @@ struct Add
 {
     using dataType = T;

-    __host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
+    __host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); };

     __device__ static constexpr bool
     IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
@@ -76,7 +76,7 @@ struct Mul
 {
     using dataType = T;

-    __host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(1.0f); };
+    __host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(1.0f); };

     __device__ static constexpr bool
     IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
@@ -92,7 +92,7 @@ struct Max
 {
     using dataType = T;

-    __host__ __device__ static constexpr T GetReductionZeroVal()
+    __host__ __device__ static constexpr T GetIdentityValue()
     {
         return NumericLimits<T>::Lowest();
     };
@@ -125,10 +125,7 @@ struct Min
 {
     using dataType = T;

-    __host__ __device__ static constexpr T GetReductionZeroVal()
-    {
-        return NumericLimits<T>::Max();
-    };
+    __host__ __device__ static constexpr T GetIdentityValue() { return NumericLimits<T>::Max(); };

     __device__ static constexpr bool
     IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
@@ -158,7 +155,7 @@ struct AMax
 {
     using dataType = T;

-    __host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
+    __host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); };

     __device__ static constexpr bool
     IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
@@ -184,7 +181,7 @@ struct AMax
 };

 template <typename T>
-T GetReductionZeroValueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
+T GetIdentityValueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
 {
     T result = ck::type_convert<T>(0.0f);
......
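Per the contract in the comment above, a new reduce op needs only the identity interface, the in-memory-operation compatibility check, and the combine operator. A hedged sketch of such a functor, copying the shape of Add above; the op itself and the body of the compatibility check are illustrative, not something this commit or CK ships:

```cpp
// Hypothetical sum-of-squares accumulation functor following the member
// contract above. The operator() signature matches the d_reduce_op(d_acc, v)
// call shape seen earlier in this commit; the compatibility body mirrors what
// an Add-like op would plausibly allow (assumed, not shown in this diff).
template <typename T>
struct SumOfSquares
{
    using dataType = T;

    __host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); };

    __device__ static constexpr bool
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        return operation == InMemoryDataOperationEnum::AtomicAdd ||
               operation == InMemoryDataOperationEnum::Set;
    };

    __host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a + b * b; };
};
```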
deleted file: host_reduce_util.hpp

/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef GUARD_HOST_REDUCE_UTIL_HPP
#define GUARD_HOST_REDUCE_UTIL_HPP
#include <limits>
#include <cmath>
#include <functional>
#include "reduction_enums.hpp"
#include "data_type.hpp"
#include "math_v2.hpp"
namespace ck {
namespace host_reduce {
using ck::NanPropagation;
using ck::ReduceTensorOp;
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&)> PreUnaryOpFn(int)
{
using ck::math::abs;
if constexpr(ReduceOpId == ReduceTensorOp::NORM1)
{
return ([&](AccDataType& a_) { a_ = abs(a_); });
}
else if constexpr(ReduceOpId == ReduceTensorOp::NORM2)
{
return ([&](AccDataType& a_) { a_ = a_ * a_; });
}
else if constexpr(ReduceOpId == ReduceTensorOp::AMAX)
{
return ([&](AccDataType& a_) { a_ = abs(a_); });
}
else
{
// ReduceTensorOp::AVG:
// ReduceTensorOp::ADD:
// ReduceTensorOp::MUL:
// ReduceTensorOp::MIN:
// ReduceTensorOp::MAX:
return ([&](AccDataType&) {});
};
};
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&)> PosUnaryOpFn(int32_t divider)
{
using std::sqrt;
if constexpr(ReduceOpId == ReduceTensorOp::NORM2)
{
return ([&](AccDataType& a_) { a_ = sqrt(a_); });
}
else if constexpr(ReduceOpId == ReduceTensorOp::AVG)
{
return ([&, divider](AccDataType& a_) {
a_ = a_ / static_cast<AccDataType>(static_cast<float>(divider));
});
}
else
{
// ReduceTensorOp::ADD:
// ReduceTensorOp::NORM1:
// ReduceTensorOp::MUL:
// ReduceTensorOp::MIN:
// ReduceTensorOp::MAX:
// ReduceTensorOp::AMAX:
return ([&](AccDataType&) {});
}
};
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&, AccDataType)> ReduceOpFn()
{
if constexpr(ReduceOpId == ReduceTensorOp::ADD || ReduceOpId == ReduceTensorOp::AVG ||
ReduceOpId == ReduceTensorOp::NORM1 || ReduceOpId == ReduceTensorOp::NORM2)
{
return ([&](AccDataType& a_, AccDataType b_) { a_ = a_ + b_; });
}
else if constexpr(ReduceOpId == ReduceTensorOp::MUL)
{
return ([&](AccDataType& a_, AccDataType b_) { a_ = a_ * b_; });
}
else if constexpr(ReduceOpId == ReduceTensorOp::MIN)
{
return ([&](AccDataType& a_, AccDataType b_) {
if(a_ > b_)
a_ = b_;
});
}
else if constexpr(ReduceOpId == ReduceTensorOp::MAX || ReduceOpId == ReduceTensorOp::AMAX)
{
return ([&](AccDataType& a_, AccDataType b_) {
if(a_ < b_)
a_ = b_;
});
}
};
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&, AccDataType, bool& changed)> ReduceOpFn2()
{
if constexpr(ReduceOpId == ReduceTensorOp::MIN)
{
return ([&](AccDataType& a_, AccDataType b_, bool& changed) {
if(a_ > b_)
{
a_ = b_;
changed = true;
}
else
changed = false;
});
}
else if constexpr(ReduceOpId == ReduceTensorOp::MAX || ReduceOpId == ReduceTensorOp::AMAX)
{
return ([&](AccDataType& a_, AccDataType b_, bool& changed) {
if(a_ < b_)
{
a_ = b_;
changed = true;
}
else
changed = false;
});
}
else
{
// ReduceTensorOp::ADD:
// ReduceTensorOp::MUL:
// ReduceTensorOp::AVG:
// ReduceTensorOp::NORM1:
// ReduceTensorOp::NORM2:
return (std::function<void(AccDataType&, AccDataType, bool&)>{});
};
};
template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline AccDataType ReduceOpZeroVal()
{
if constexpr(ReduceOpId == ReduceTensorOp::MUL)
{
return (static_cast<AccDataType>(1.0f));
}
else if constexpr(ReduceOpId == ReduceTensorOp::MIN)
{
return (ck::NumericLimits<AccDataType>::Max());
}
else if constexpr(ReduceOpId == ReduceTensorOp::MAX)
{
return (ck::NumericLimits<AccDataType>::Lowest());
}
else if constexpr(ReduceOpId == ReduceTensorOp::AMAX)
{
return (static_cast<AccDataType>(0.0f));
}
else
{
// ReduceTensorOp::ADD
// ReduceTensorOp::AVG
// ReduceTensorOp::NORM1
// ReduceTensorOp::NORM2
return (static_cast<AccDataType>(0.0f));
};
};
template <typename AccDataType, bool PropagateNan>
__host__ static inline void
binop_with_nan_check(std::function<void(AccDataType&, AccDataType)> opReduce,
AccDataType& accuVal,
AccDataType currVal)
{
using ck::math::isnan;
if constexpr(!PropagateNan)
{
opReduce(accuVal, currVal);
}
else
{
if(isnan(currVal))
accuVal = currVal;
else
opReduce(accuVal, currVal);
};
};
template <typename AccDataType, typename IndexDataType, bool PropagateNan>
__host__ static inline void
binop_with_index_and_nan_check(std::function<void(AccDataType&, AccDataType, bool&)> opReduce,
AccDataType& accuVal,
AccDataType currVal,
IndexDataType& accuIndex,
IndexDataType currIndex)
{
using ck::math::isnan;
if constexpr(!PropagateNan)
{
bool changed;
opReduce(accuVal, currVal, changed);
if(changed)
accuIndex = currIndex;
}
else
{
if(isnan(currVal))
{
accuVal = currVal;
accuIndex = currIndex;
}
else
{
bool changed;
opReduce(accuVal, currVal, changed);
if(changed)
accuIndex = currIndex;
};
};
};
}; // namespace host_reduce
}; // namespace ck
#endif
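The deleted header's helpers map onto this commit's replacements as follows; each pairing is taken from the hunks elsewhere in this diff:

* PreUnaryOpFn / PosUnaryOpFn -> InElementwiseOperation / AccElementwiseOperation functor instances, applied in place as in_elementwise_op(v, v) and acc_elementwise_op(v, v)
* ReduceOpZeroVal -> ReduceOperation::GetIdentityValue()
* ReduceOpFn + binop_with_nan_check -> ck::detail::AccumulateWithNanCheck<...>::Calculate(accuVal, currVal)
* ReduceOpFn2 + binop_with_index_and_nan_check -> ck::detail::AccumulateWithIndexAndNanCheck<...>::Calculate(accuVal, currVal, accuIndex, currIndex)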
@@ -33,10 +33,10 @@
 #include "reduction_enums.hpp"
 #include "reduction_common.hpp"
-#include "host_reduce_util.hpp"
 #include "host_common_util.hpp"
 #include "host_tensor.hpp"
 #include "data_type.hpp"
+#include "reduction_functions_accumulate.hpp"

 template <int NDim>
 static void get_all_indexes(const std::array<size_t, NDim>& dimLengths,
@@ -106,11 +106,13 @@ static size_t get_offset_from_index(const std::vector<size_t>& strides,
 template <typename InDataType,
           typename AccDataType,
           typename OutDataType,
-          ck::ReduceTensorOp ReduceOpId,
+          typename ReduceOperation,
+          typename InElementwiseOperation,
+          typename AccElementwiseOperation,
           int Rank,
           int NumReduceDim,
           bool PropagateNan,
-          bool NeedIndices>
+          bool OutputIndex>
 struct ReductionHost
 {
     using IndexDataType = int32_t;
@@ -122,8 +124,6 @@ struct ReductionHost
     std::vector<int> reduceDims;
     IndexDataType divider;

-    std::function<void(AccDataType&)> preUnaryOp;
-    std::function<void(AccDataType&)> posUnaryOp;
     std::array<size_t, NumReduceDim> reduceLengths;
     std::array<size_t, NumReduceDim> reduceStrides;
     std::array<size_t, NumInvariantDim> invariantLengths;
@@ -137,9 +137,6 @@ struct ReductionHost
                   const std::vector<int>& invariantDims_,
                   const std::vector<int>& reduceDims_)
     {
-        using ck::host_reduce::PosUnaryOpFn;
-        using ck::host_reduce::PreUnaryOpFn;
-
         // this->outLengths = to_int_vector(outDesc.GetLengths());
         this->outStrides = outDesc.GetStrides();
@@ -171,9 +168,6 @@ struct ReductionHost
             invariant_dim_indexes.clear();
             get_all_indexes<NumInvariantDim>(invariantLengths, invariant_dim_indexes);
         };
-
-        preUnaryOp = PreUnaryOpFn<AccDataType, ReduceOpId>(divider);
-        posUnaryOp = PosUnaryOpFn<AccDataType, ReduceOpId>(divider);
     };

     void Run(float alpha,
@@ -182,7 +176,7 @@ struct ReductionHost
              OutDataType* out_data,
              IndexDataType* out_indices)
     {
-        if constexpr(NeedIndices)
+        if constexpr(OutputIndex)
         {
             RunImpl_with_index(alpha, in_data, beta, out_data, out_indices);
         }
@@ -201,15 +195,17 @@ struct ReductionHost
         using ck::float_equal_one;
         using ck::float_equal_zero;
         using ck::type_convert;
-        using ck::host_reduce::binop_with_index_and_nan_check;
-        using ck::host_reduce::ReduceOpFn2;
-        using ck::host_reduce::ReduceOpZeroVal;

-        auto opReduce2 = ReduceOpFn2<AccDataType, ReduceOpId>();
+        using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
+                                                                        ReduceOperation,
+                                                                        AccDataType,
+                                                                        IndexDataType>;
+
+        InElementwiseOperation in_elementwise_op(divider);
+        AccElementwiseOperation acc_elementwise_op(divider);

         if constexpr(NumInvariantDim == 0)
         {
-            AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
+            AccDataType accuVal = ReduceOperation::GetIdentityValue();
             IndexDataType accuIndex = 0;

             for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++)
@@ -219,15 +215,14 @@ struct ReductionHost
                 auto currVal = type_convert<AccDataType>(in_data[offset_reduce]);

-                preUnaryOp(currVal);
+                in_elementwise_op(currVal, currVal);

                 auto currIndex = static_cast<IndexDataType>(i);

-                binop_with_index_and_nan_check<AccDataType, IndexDataType, PropagateNan>(
-                    opReduce2, accuVal, currVal, accuIndex, currIndex);
+                Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
             };

-            posUnaryOp(accuVal);
+            acc_elementwise_op(accuVal, accuVal);

             if(!float_equal_one{}(alpha))
                 accuVal *= type_convert<AccDataType>(alpha);
@@ -241,7 +236,7 @@ struct ReductionHost
         else
         {
             auto thread_reduce_func = [&](auto invariant_index) {
-                AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
+                AccDataType accuVal = ReduceOperation::GetIdentityValue();
                 IndexDataType accuIndex = 0;

                 auto offset_invariant =
@@ -255,15 +250,14 @@ struct ReductionHost
                     auto currVal =
                         type_convert<AccDataType>(in_data[offset_invariant + offset_reduce]);

-                    preUnaryOp(currVal);
+                    in_elementwise_op(currVal, currVal);

                     auto currIndex = static_cast<IndexDataType>(i);

-                    binop_with_index_and_nan_check<AccDataType, IndexDataType, PropagateNan>(
-                        opReduce2, accuVal, currVal, accuIndex, currIndex);
+                    Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
                 };

-                posUnaryOp(accuVal);
+                acc_elementwise_op(accuVal, accuVal);

                 if(!float_equal_one{}(alpha))
                     accuVal *= type_convert<AccDataType>(alpha);
@@ -308,15 +302,16 @@ struct ReductionHost
         using ck::float_equal_one;
         using ck::float_equal_zero;
         using ck::type_convert;
-        using ck::host_reduce::binop_with_nan_check;
-        using ck::host_reduce::ReduceOpFn;
-        using ck::host_reduce::ReduceOpZeroVal;

-        auto opReduce = ReduceOpFn<AccDataType, ReduceOpId>();
+        using Accumulation =
+            ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
+
+        InElementwiseOperation in_elementwise_op(divider);
+        AccElementwiseOperation acc_elementwise_op(divider);

         if constexpr(NumInvariantDim == 0)
         {
-            AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
+            AccDataType accuVal = ReduceOperation::GetIdentityValue();

             for(const auto& reduce_index : reduce_dim_indexes)
             {
@@ -325,12 +320,12 @@ struct ReductionHost
                 auto currVal = type_convert<AccDataType>(in_data[offset_reduce]);

-                preUnaryOp(currVal);
+                in_elementwise_op(currVal, currVal);

-                binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal);
+                Accumulation::Calculate(accuVal, currVal);
             };

-            posUnaryOp(accuVal);
+            acc_elementwise_op(accuVal, accuVal);

             if(!float_equal_one{}(alpha))
                 accuVal *= type_convert<AccDataType>(alpha);
@@ -343,7 +338,7 @@ struct ReductionHost
         else
         {
             auto thread_reduce_func = [&](auto invariant_index) {
-                AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
+                AccDataType accuVal = ReduceOperation::GetIdentityValue();

                 auto offset_invariant =
                     get_offset_from_index<NumInvariantDim>(invariantStrides, invariant_index);
@@ -356,12 +351,12 @@ struct ReductionHost
                     auto currVal =
                         type_convert<AccDataType>(in_data[offset_invariant + offset_reduce]);

-                    preUnaryOp(currVal);
+                    in_elementwise_op(currVal, currVal);

-                    binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal);
+                    Accumulation::Calculate(accuVal, currVal);
                 };

-                posUnaryOp(accuVal);
+                acc_elementwise_op(accuVal, accuVal);

                 if(!float_equal_one{}(alpha))
                     accuVal *= type_convert<AccDataType>(alpha);
......
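A hedged sketch of driving the refactored host reference; buffer sizes are assumed from the descriptors, and the argument order follows the RunImpl_with_index call shown above (the middle Run parameters are elided by this diff):

```cpp
// Assumed: hostReduce was instantiated as in the examples earlier in this
// commit, and the buffers match the lengths implied by inDesc/outDesc.
std::vector<InOutDataType> in(total_input_size);
std::vector<InOutDataType> out(total_output_size);
std::vector<int32_t>       out_indices(total_output_size); // used only when OutputIndex

hostReduce.Run(1.0f,       // alpha: scales the reduced value (accuVal *= alpha above)
               in.data(),
               0.0f,       // beta: blends in the prior output value when nonzero
               out.data(),
               out_indices.data());
```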
@@ -171,8 +171,8 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
     {
         for(int m = 0; m < M; ++m)
         {
-            float d0_acc = d0_reduce_op.GetReductionZeroVal();
-            float d1_acc = d1_reduce_op.GetReductionZeroVal();
+            float d0_acc = d0_reduce_op.GetIdentityValue();
+            float d1_acc = d1_reduce_op.GetIdentityValue();

             for(int n = 0; n < N; ++n)
             {
......