#ifndef REFERENCE_CONV_FWD_BIAS_ACTIVATION_HPP
#define REFERENCE_CONV_FWD_BIAS_ACTIVATION_HPP

#include <iostream>
#include <memory>
#include <sstream>
#include <thread>
#include <vector>

#include "device_base.hpp"
#include "host_tensor.hpp"

namespace ck {
namespace tensor_operation {
namespace host {

// out[N, K, Ho, Wo] = activate(in[N, C, Hi, Wi] * wei[K, C, Y, X] + bias[K])
template <typename InDataType,
          typename WeiDataType,
          typename OutDataType,
          typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation>
struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator
{
    // Argument
    struct Argument : public device::BaseArgument
    {
        Argument(const Tensor<InDataType>& in_n_c_hi_wi,
                 const Tensor<WeiDataType>& wei_k_c_y_x,
                 Tensor<OutDataType>& out_n_k_ho_wo,
                 const Tensor<OutDataType>& bias_k,
                 std::vector<ck::index_t> conv_filter_strides,
                 std::vector<ck::index_t> conv_filter_dilations,
                 std::vector<ck::index_t> input_left_pads,
                 std::vector<ck::index_t> input_right_pads,
                 InElementwiseOperation in_element_op,
                 WeiElementwiseOperation wei_element_op,
                 OutElementwiseOperation out_element_op)
            : in_n_c_hi_wi_{in_n_c_hi_wi},
              wei_k_c_y_x_{wei_k_c_y_x},
              out_n_k_ho_wo_{out_n_k_ho_wo},
              bias_k_{bias_k},
              conv_strides_{conv_filter_strides},
              conv_dilations_{conv_filter_dilations},
              in_left_pads_{input_left_pads},
              in_right_pads_{input_right_pads},
              in_element_op_{in_element_op},
              wei_element_op_{wei_element_op},
              out_element_op_{out_element_op}
        {
        }

        const Tensor<InDataType>& in_n_c_hi_wi_;
        const Tensor<WeiDataType>& wei_k_c_y_x_;
        Tensor<OutDataType>& out_n_k_ho_wo_;
        const Tensor<OutDataType>& bias_k_;

        std::vector<ck::index_t> conv_strides_;
        std::vector<ck::index_t> conv_dilations_;
        std::vector<ck::index_t> in_left_pads_;
        std::vector<ck::index_t> in_right_pads_;

        InElementwiseOperation in_element_op_;
        WeiElementwiseOperation wei_element_op_;
        OutElementwiseOperation out_element_op_;
    };

    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        using Argument = ReferenceConvFwd_Bias_Activation::Argument;

        float Run(const Argument& arg)
        {
            // Compute one output element out[n, k, ho, wo].
            auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
                float v = 0;

                for(int c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c)
                {
                    for(int y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y)
                    {
                        // Map output row to input row using stride, dilation and left padding.
                        int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] -
                                 arg.in_left_pads_[0];

                        for(int x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x)
                        {
                            int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] -
                                     arg.in_left_pads_[1];

                            // Coordinates that fall into the (implicitly zero) padding contribute nothing.
                            if(hi >= 0 && hi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] &&
                               wi >= 0 && wi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3])
                            {
                                v += arg.in_element_op_(ck::type_convert<float>(
                                         arg.in_n_c_hi_wi_(n, c, hi, wi))) *
                                     arg.wei_element_op_(
                                         ck::type_convert<float>(arg.wei_k_c_y_x_(k, c, y, x)));
                            }
                        }
                    }
                }

                // Bias add and activation are fused into the output element-wise operation.
                arg.out_n_k_ho_wo_(n, k, ho, wo) =
                    ck::type_convert<OutDataType>(arg.out_element_op_(v, arg.bias_k_(k)));
            };

            make_ParallelTensorFunctor(f_nchw,
                                       arg.out_n_k_ho_wo_.mDesc.GetLengths()[0],
                                       arg.out_n_k_ho_wo_.mDesc.GetLengths()[1],
                                       arg.out_n_k_ho_wo_.mDesc.GetLengths()[2],
                                       arg.out_n_k_ho_wo_.mDesc.GetLengths()[3])(
                std::thread::hardware_concurrency());

            return 0;
        }

        float Run(const device::BaseArgument* p_arg, int) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    static auto MakeArgument(const Tensor<InDataType>& in_n_c_hi_wi,
                             const Tensor<WeiDataType>& wei_k_c_y_x,
                             Tensor<OutDataType>& out_n_k_ho_wo,
                             const Tensor<OutDataType>& bias_k,
                             std::vector<ck::index_t> conv_filter_strides,
                             std::vector<ck::index_t> conv_filter_dilations,
                             std::vector<ck::index_t> input_left_pads,
                             std::vector<ck::index_t> input_right_pads,
                             InElementwiseOperation in_element_op,
                             WeiElementwiseOperation wei_element_op,
                             OutElementwiseOperation out_element_op)
    {
        return Argument{in_n_c_hi_wi,
                        wei_k_c_y_x,
                        out_n_k_ho_wo,
                        bias_k,
                        conv_filter_strides,
                        conv_filter_dilations,
                        input_left_pads,
                        input_right_pads,
                        in_element_op,
                        wei_element_op,
                        out_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceConvFwd_Bias_Activation"
            << std::endl;
        // clang-format on

        return str.str();
    }
};

} // namespace host
} // namespace tensor_operation
} // namespace ck
#endif
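
// -----------------------------------------------------------------------------
// Illustrative usage (kept in a comment so it does not affect compilation).
// This is a minimal sketch, not part of the library API: PassThroughOp and
// AddReluOp below are hypothetical stand-ins written only to match the call
// shapes used by Invoker::Run (a unary float op for input/weight, a binary
// (acc, bias) op for the output); the library's own element-wise operators may
// have different signatures. It also assumes Tensor<T> can be constructed from
// a list of dimension lengths, as provided by the host_tensor utilities.
//
// struct PassThroughOp
// {
//     float operator()(float x) const { return x; }
// };
//
// struct AddReluOp
// {
//     float operator()(float acc, float bias) const
//     {
//         float y = acc + bias;
//         return y > 0.f ? y : 0.f;
//     }
// };
//
// void run_reference_conv()
// {
//     using RefConv = ck::tensor_operation::host::ReferenceConvFwd_Bias_Activation<
//         float, float, float, PassThroughOp, PassThroughOp, AddReluOp>;
//
//     // NCHW input, KCYX weight, NKHW output, 1-D bias over K.
//     Tensor<float> in({1, 8, 16, 16});
//     Tensor<float> wei({4, 8, 3, 3});
//     Tensor<float> out({1, 4, 16, 16});
//     Tensor<float> bias({4});
//
//     auto ref = RefConv{};
//     auto arg = ref.MakeArgument(in, wei, out, bias,
//                                 {1, 1},  // strides
//                                 {1, 1},  // dilations
//                                 {1, 1},  // left pads
//                                 {1, 1},  // right pads
//                                 PassThroughOp{}, PassThroughOp{}, AddReluOp{});
//
//     auto invoker = ref.MakeInvoker();
//     invoker.Run(arg);
// }
// -----------------------------------------------------------------------------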