// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <vector>
#include <array>
#include <thread>

#include "ck/utility/data_type.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_tensor.hpp"

// Enumerate every multi-dimensional index of a tensor with the given dimension lengths.
template <int NDim>
static void get_all_indexes(const std::array<size_t, NDim>& dimLengths,
                            std::vector<std::array<size_t, NDim>>& indexes)
{
    static_assert(NDim >= 1, "NDim >= 1 is required to use this function!");

    if constexpr(NDim == 1)
    {
        for(size_t i = 0; i < dimLengths[0]; i++)
        {
            std::array<size_t, 1> index{i};

            indexes.push_back(index);
        };
    }
    else
    {
        std::array<size_t, NDim - 1> partial_dim_lengths;

        for(int i = 0; i < NDim - 1; i++)
            partial_dim_lengths[i] = dimLengths[i + 1];

        std::vector<std::array<size_t, NDim - 1>> partial_indexes;

        get_all_indexes<NDim - 1>(partial_dim_lengths, partial_indexes);

        // prepend each value of the outer-most dimension to the indexes of the remaining dims
        for(size_t i = 0; i < dimLengths[0]; i++)
            for(const auto& index : partial_indexes)
            {
                std::array<size_t, NDim> extIndex;

                extIndex[0] = i;

                for(int k = 0; k < NDim - 1; k++)
                    extIndex[k + 1] = index[k];

                indexes.push_back(extIndex);
            };
    };
};

// Compute the linear element offset of a multi-dimensional index using the given strides.
template <int NDim>
static size_t get_offset_from_index(const std::array<size_t, NDim>& strides,
                                    const std::array<size_t, NDim>& index)
{
    size_t offset = 0;

    for(int i = 0; i < NDim; i++)
        offset += strides[i] * index[i];

    return (offset);
};

template <int NDim>
static size_t get_offset_from_index(const std::vector<size_t>& strides,
                                    const std::array<size_t, NDim>& index)
{
    size_t offset = 0;

    for(int i = 0; i < NDim; i++)
        offset += strides[i] * index[i];

    return (offset);
};

// Host (CPU) reference implementation of a tensor reduction over NumReduceDim dimensions
// of a Rank-dimensional tensor, optionally tracking the index of the selected element.
template <typename InDataType,
          typename AccDataType,
          typename OutDataType,
          typename ReduceOperation,
          typename InElementwiseOperation,
          typename AccElementwiseOperation,
          int Rank,
          int NumReduceDim,
          bool PropagateNan,
          bool OutputIndex>
struct ReductionHost
{
    using IndexDataType = int32_t;

    static constexpr int NumInvariantDim = Rank - NumReduceDim;

    std::vector<size_t> outStrides;

    IndexDataType divider;
    std::array<size_t, NumReduceDim> reduceLengths;
    std::array<size_t, NumReduceDim> reduceStrides;
    std::array<size_t, NumInvariantDim> invariantLengths;
    std::array<size_t, NumInvariantDim> invariantStrides;

    std::vector<std::array<size_t, NumReduceDim>> reduce_dim_indexes;
    std::vector<std::array<size_t, NumInvariantDim>> invariant_dim_indexes;

    ReductionHost(HostTensorDescriptor& inDesc,
                  HostTensorDescriptor& outDesc,
                  const std::array<int, NumInvariantDim> invariantDims,
                  const std::array<int, NumReduceDim> reduceDims)
    {
        // this->outLengths = to_int_vector(outDesc.GetLengths());
        this->outStrides = outDesc.GetStrides();

        int product = 1;

        for(int i = 0; i < NumReduceDim; i++)
        {
            reduceLengths[i] = inDesc.GetLengths()[reduceDims[i]];
            reduceStrides[i] = inDesc.GetStrides()[reduceDims[i]];
            product *= inDesc.GetLengths()[reduceDims[i]];
        };

        divider = product;

        for(int i = 0; i < NumInvariantDim; i++)
        {
            invariantLengths[i] = inDesc.GetLengths()[invariantDims[i]];
            invariantStrides[i] = inDesc.GetStrides()[invariantDims[i]];
        };

        reduce_dim_indexes.clear();
        get_all_indexes<NumReduceDim>(reduceLengths, reduce_dim_indexes);

        if constexpr(NumInvariantDim > 0)
        {
            invariant_dim_indexes.clear();
            get_all_indexes<NumInvariantDim>(invariantLengths, invariant_dim_indexes);
        };
    };

    void Run(float alpha,
             const InDataType* in_data,
             float beta,
             OutDataType* out_data,
             IndexDataType* out_indices,
             InElementwiseOperation in_elementwise_op,
             AccElementwiseOperation acc_elementwise_op)
    {
        if constexpr(OutputIndex)
        {
            RunImpl_with_index(
                alpha, in_data, beta, out_data, out_indices, in_elementwise_op, acc_elementwise_op);
        }
        else
        {
            RunImpl_no_index(alpha, in_data, beta, out_data, in_elementwise_op, acc_elementwise_op);
        };
    };

    // Reduction that also records, per output element, the linearized index of the selected
    // input element within the reduced dimensions (e.g. for ArgMax/ArgMin style reductions).
    void RunImpl_with_index(float alpha,
                            const InDataType* in_data,
                            float beta,
                            OutDataType* out_data,
                            IndexDataType* out_indices,
                            InElementwiseOperation in_elementwise_op,
                            AccElementwiseOperation acc_elementwise_op)
    {
        using ck::float_equal_one;
        using ck::float_equal_zero;
        using ck::type_convert;

        using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
                                                                        ReduceOperation,
                                                                        AccDataType,
                                                                        IndexDataType>;

        if constexpr(NumInvariantDim == 0)
        {
            AccDataType accuVal     = ReduceOperation::template GetIdentityValue<AccDataType>();
            IndexDataType accuIndex = 0;

            for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++)
            {
                auto offset_reduce =
                    get_offset_from_index<NumReduceDim>(reduceStrides, reduce_dim_indexes[i]);

                auto currVal = type_convert<AccDataType>(in_data[offset_reduce]);

                in_elementwise_op(currVal, currVal);

                auto currIndex = static_cast<IndexDataType>(i);

                Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
            };

            acc_elementwise_op(accuVal, accuVal);

            if(!float_equal_one{}(alpha))
                accuVal *= type_convert<AccDataType>(alpha);

            if(!float_equal_zero{}(beta))
                accuVal += type_convert<AccDataType>(out_data[0]) * type_convert<AccDataType>(beta);

            out_data[0]    = type_convert<OutDataType>(accuVal);
            out_indices[0] = accuIndex;
        }
        else
        {
            auto thread_reduce_func = [&](auto invariant_index) {
                AccDataType accuVal     = ReduceOperation::template GetIdentityValue<AccDataType>();
                IndexDataType accuIndex = 0;

                auto offset_invariant =
                    get_offset_from_index<NumInvariantDim>(invariantStrides, invariant_index);

                for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++)
                {
                    auto offset_reduce =
                        get_offset_from_index<NumReduceDim>(reduceStrides, reduce_dim_indexes[i]);

                    auto currVal =
                        type_convert<AccDataType>(in_data[offset_invariant + offset_reduce]);

                    in_elementwise_op(currVal, currVal);

                    auto currIndex = static_cast<IndexDataType>(i);

                    Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
                };

                acc_elementwise_op(accuVal, accuVal);

                if(!float_equal_one{}(alpha))
                    accuVal *= type_convert<AccDataType>(alpha);

                auto dst_offset =
                    get_offset_from_index<NumInvariantDim>(outStrides, invariant_index);

                if(!float_equal_zero{}(beta))
                    accuVal += type_convert<AccDataType>(out_data[dst_offset]) *
                               type_convert<AccDataType>(beta);

                out_data[dst_offset]    = type_convert<OutDataType>(accuVal);
                out_indices[dst_offset] = accuIndex;
            };

            // split the invariant (non-reduced) indexes across worker threads
            std::size_t num_thread = 1;
            std::size_t work_per_thread =
                (invariant_dim_indexes.size() + num_thread - 1) / num_thread;

            std::vector<joinable_thread> threads(num_thread);

            for(std::size_t it = 0; it < num_thread; ++it)
            {
                std::size_t iw_begin = it * work_per_thread;
                std::size_t iw_end =
                    std::min((it + 1) * work_per_thread, invariant_dim_indexes.size());

                auto f = [=] {
                    for(std::size_t iw = iw_begin; iw < iw_end; ++iw)
                    {
                        thread_reduce_func(invariant_dim_indexes[iw]);
                    }
                };

                threads[it] = joinable_thread(f);
            }
        };
    };

    // Plain reduction without index tracking.
    void RunImpl_no_index(float alpha,
                          const InDataType* in_data,
                          float beta,
                          OutDataType* out_data,
                          InElementwiseOperation in_elementwise_op,
                          AccElementwiseOperation acc_elementwise_op)
    {
        using ck::float_equal_one;
        using ck::float_equal_zero;
        using ck::type_convert;

        using Accumulation =
            ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;

        if constexpr(NumInvariantDim == 0)
        {
            AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();

            for(const auto& reduce_index : reduce_dim_indexes)
            {
                auto offset_reduce =
                    get_offset_from_index<NumReduceDim>(reduceStrides, reduce_index);

                auto currVal = type_convert<AccDataType>(in_data[offset_reduce]);

                in_elementwise_op(currVal, currVal);

                Accumulation::Calculate(accuVal, currVal);
            };

            acc_elementwise_op(accuVal, accuVal);

            if(!float_equal_one{}(alpha))
                accuVal *= type_convert<AccDataType>(alpha);

            if(!float_equal_zero{}(beta))
                accuVal += type_convert<AccDataType>(out_data[0]) * type_convert<AccDataType>(beta);

            out_data[0] = type_convert<OutDataType>(accuVal);
        }
        else
        {
            auto thread_reduce_func = [&](auto invariant_index) {
                AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();

                auto offset_invariant =
                    get_offset_from_index<NumInvariantDim>(invariantStrides, invariant_index);

                for(const auto& reduce_index : reduce_dim_indexes)
                {
                    auto offset_reduce =
                        get_offset_from_index<NumReduceDim>(reduceStrides, reduce_index);

                    auto currVal =
                        type_convert<AccDataType>(in_data[offset_invariant + offset_reduce]);

                    in_elementwise_op(currVal, currVal);

                    Accumulation::Calculate(accuVal, currVal);
                };

                acc_elementwise_op(accuVal, accuVal);

                if(!float_equal_one{}(alpha))
                    accuVal *= type_convert<AccDataType>(alpha);

                auto dst_offset =
                    get_offset_from_index<NumInvariantDim>(outStrides, invariant_index);

                if(!float_equal_zero{}(beta))
                    accuVal += type_convert<AccDataType>(out_data[dst_offset]) *
                               type_convert<AccDataType>(beta);

                out_data[dst_offset] = type_convert<OutDataType>(accuVal);
            };

            // split the invariant (non-reduced) indexes across worker threads
            std::size_t num_thread = 1;
            std::size_t work_per_thread =
                (invariant_dim_indexes.size() + num_thread - 1) / num_thread;

            std::vector<joinable_thread> threads(num_thread);

            for(std::size_t it = 0; it < num_thread; ++it)
            {
                std::size_t iw_begin = it * work_per_thread;
                std::size_t iw_end =
                    std::min((it + 1) * work_per_thread, invariant_dim_indexes.size());

                auto f = [=] {
                    for(std::size_t iw = iw_begin; iw < iw_end; ++iw)
                    {
                        thread_reduce_func(invariant_dim_indexes[iw]);
                    }
                };

                threads[it] = joinable_thread(f);
            }
        };
    };
};
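
// Illustrative usage (a minimal sketch, not part of the original header): the commented
// snippet below shows how ReductionHost might be instantiated to sum a packed 2x3x4 float
// tensor over its last two dimensions. The ReduceOperation and elementwise-operation types
// named here (ck::reduce::Add, ck::tensor_operation::element_wise::PassThrough) are
// assumptions based on other parts of the library and are not required by this header.
//
//   using PassThrough = ck::tensor_operation::element_wise::PassThrough;
//
//   HostTensorDescriptor inDesc(std::vector<std::size_t>{2, 3, 4}); // packed strides from lengths
//   HostTensorDescriptor outDesc(std::vector<std::size_t>{2});
//
//   std::vector<float> in(2 * 3 * 4, 1.0f);
//   std::vector<float> out(2, 0.0f);
//
//   // Rank = 3, NumReduceDim = 2, PropagateNan = false, OutputIndex = false
//   ReductionHost<float, float, float, ck::reduce::Add, PassThrough, PassThrough, 3, 2, false, false>
//       hostReduce(inDesc, outDesc, {0}, {1, 2});
//
//   // out_indices is unused when OutputIndex == false, so nullptr is passed
//   hostReduce.Run(1.0f, in.data(), 0.0f, out.data(), nullptr, PassThrough{}, PassThrough{});
//   // each out[i] is now 12.0f (the sum of the 3 * 4 reduced elements)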