#ifndef REFERENCE_CONV_FWD_HPP
#define REFERENCE_CONV_FWD_HPP

#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <type_traits>
#include <vector>

#include "device_base.hpp"
#include "host_tensor.hpp"

namespace ck {
namespace tensor_operation {
namespace host {

//
// @brief      Reference implementation for forward convolution.
//
// @paragraph  Supported tensor layouts.
//             Input tensor supports NCHiWi data layout.
//             Weights tensor supports KCYX data layout.
//             Output tensor supports NKHoWo data layout.
//
// @tparam     InDataType               Input tensor data type.
// @tparam     WeiDataType              Weights tensor data type.
// @tparam     OutDataType              Output tensor data type.
// @tparam     InElementwiseOperation   Functor for input tensor elementwise operation.
// @tparam     WeiElementwiseOperation  Functor for weights tensor elementwise operation.
// @tparam     OutElementwiseOperation  Functor for output tensor elementwise operation.
// @tparam     NumDimSpatial            Number of spatial dimensions.
//
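// Example usage (illustrative sketch only): the element-wise functor "PassThrough" and the
// pre-sized host tensors input/weight/output are assumptions made for this example, not part
// of this header.
//
//     using PassThrough = ck::tensor_operation::element_wise::PassThrough;
//
//     using RefConvFwd = ck::tensor_operation::host::ReferenceConvFwd<float,        // InDataType
//                                                                     float,        // WeiDataType
//                                                                     float,        // OutDataType
//                                                                     PassThrough,  // in op
//                                                                     PassThrough,  // wei op
//                                                                     PassThrough,  // out op
//                                                                     2>;           // NumDimSpatial
//
//     auto ref_conv     = RefConvFwd{};
//     auto ref_invoker  = ref_conv.MakeInvoker();
//     auto ref_argument = ref_conv.MakeArgument(input,          // Tensor<float>, NCHiWi
//                                               weight,         // Tensor<float>, KCYX
//                                               output,         // Tensor<float>, NKHoWo
//                                               {1, 1},         // conv_filter_strides
//                                               {1, 1},         // conv_filter_dilations
//                                               {1, 1},         // input_left_pads
//                                               {1, 1},         // input_right_pads
//                                               PassThrough{},
//                                               PassThrough{},
//                                               PassThrough{});
//     ref_invoker.Run(ref_argument);
//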
template <typename InDataType,
          typename WeiDataType,
          typename OutDataType,
          typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation,
          ck::index_t NumDimSpatial = 2,
          typename std::enable_if<NumDimSpatial >= 1 && NumDimSpatial <= 3, bool>::type = false>
struct ReferenceConvFwd : public device::BaseOperator
{
    // Argument
    struct Argument : public device::BaseArgument
    {
        Argument(const Tensor<InDataType>& input,
                 const Tensor<WeiDataType>& weight,
                 Tensor<OutDataType>& output,
                 std::vector<ck::index_t> conv_filter_strides,
                 std::vector<ck::index_t> conv_filter_dilations,
                 std::vector<ck::index_t> input_left_pads,
                 std::vector<ck::index_t> input_right_pads,
                 InElementwiseOperation in_element_op,
                 WeiElementwiseOperation wei_element_op,
                 OutElementwiseOperation out_element_op)
            : input_{input},
              weight_{weight},
              output_{output},
              conv_strides_{conv_filter_strides},
              conv_dilations_{conv_filter_dilations},
              in_left_pads_{input_left_pads},
              in_right_pads_{input_right_pads},
              in_element_op_{in_element_op},
              wei_element_op_{wei_element_op},
              out_element_op_{out_element_op}
        {
        }

        const Tensor<InDataType>& input_;
        const Tensor<WeiDataType>& weight_;
        Tensor<OutDataType>& output_;

        std::vector<ck::index_t> conv_strides_;
        std::vector<ck::index_t> conv_dilations_;
        std::vector<ck::index_t> in_left_pads_;
        std::vector<ck::index_t> in_right_pads_;

        InElementwiseOperation in_element_op_;
        WeiElementwiseOperation wei_element_op_;
        OutElementwiseOperation out_element_op_;
    };

    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        using Argument = ReferenceConvFwd::Argument;

        float Run(const Argument& arg)
        {
            if constexpr(NumDimSpatial == 1)
            {
                // 1D: output[n, k, wo] = sum over (c, x) of input[n, c, wi] * weight[k, c, x],
                // where wi = wo * stride + x * dilation - left_pad.
                auto f_ncw = [&](auto n, auto k, auto wo) {
                    float v_acc = 0;

                    for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
                    {
                        for(int x = 0; x < arg.weight_.mDesc.GetLengths()[2]; ++x)
                        {
                            int wi = wo * arg.conv_strides_[0] + x * arg.conv_dilations_[0] -
                                     arg.in_left_pads_[0];

                            if(wi >= 0 && wi < arg.input_.mDesc.GetLengths()[2])
                            {
                                float v_in;
                                float v_wei;

                                arg.in_element_op_(
                                    v_in, ck::type_convert<float>(arg.input_(n, c, wi)));
                                arg.wei_element_op_(
                                    v_wei, ck::type_convert<float>(arg.weight_(k, c, x)));

                                v_acc += v_in * v_wei;
                            }
                        }
                    }

                    float v_out;

                    arg.out_element_op_(v_out, v_acc);

                    arg.output_(n, k, wo) = ck::type_convert<OutDataType>(v_out);
                };

                make_ParallelTensorFunctor(f_ncw,
                                           arg.output_.mDesc.GetLengths()[0],
                                           arg.output_.mDesc.GetLengths()[1],
                                           arg.output_.mDesc.GetLengths()[2])(
                    std::thread::hardware_concurrency());

                return 0;
            }
            else if constexpr(NumDimSpatial == 2)
            {
                // 2D: output[n, k, ho, wo] = sum over (c, y, x) of
                // input[n, c, hi, wi] * weight[k, c, y, x], where
                //     hi = ho * stride_h + y * dilation_h - left_pad_h
                //     wi = wo * stride_w + x * dilation_w - left_pad_w.
                auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
                    float v_acc = 0;

                    for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c)
                    {
                        for(int y = 0; y < arg.weight_.mDesc.GetLengths()[2]; ++y)
                        {
                            int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] -
                                     arg.in_left_pads_[0];

                            for(int x = 0; x < arg.weight_.mDesc.GetLengths()[3]; ++x)
                            {
                                int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] -
                                         arg.in_left_pads_[1];

                                if(hi >= 0 && hi < arg.input_.mDesc.GetLengths()[2] && wi >= 0 &&
                                   wi < arg.input_.mDesc.GetLengths()[3])
                                {
                                    float v_in;
                                    float v_wei;

                                    arg.in_element_op_(
                                        v_in, ck::type_convert<float>(arg.input_(n, c, hi, wi)));
                                    arg.wei_element_op_(
                                        v_wei, ck::type_convert<float>(arg.weight_(k, c, y, x)));

                                    v_acc += v_in * v_wei;
                                }
                            }
                        }
                    }

                    float v_out;

                    arg.out_element_op_(v_out, v_acc);

                    arg.output_(n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
                };

                make_ParallelTensorFunctor(f_nchw,
                                           arg.output_.mDesc.GetLengths()[0],
                                           arg.output_.mDesc.GetLengths()[1],
                                           arg.output_.mDesc.GetLengths()[2],
                                           arg.output_.mDesc.GetLengths()[3])(
                    std::thread::hardware_concurrency());

                return 0;
            }
            else
            {
                // The 3D case is not implemented in this reference; fail loudly rather than
                // falling off the end of a non-void function.
                throw std::runtime_error("ReferenceConvFwd: unsupported NumDimSpatial");
            }
        }

        float Run(const device::BaseArgument* p_arg, int) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    static auto MakeArgument(const Tensor<InDataType>& input,
                             const Tensor<WeiDataType>& weight,
                             Tensor<OutDataType>& output,
                             std::vector<ck::index_t> conv_filter_strides,
                             std::vector<ck::index_t> conv_filter_dilations,
                             std::vector<ck::index_t> input_left_pads,
                             std::vector<ck::index_t> input_right_pads,
                             InElementwiseOperation in_element_op,
                             WeiElementwiseOperation wei_element_op,
                             OutElementwiseOperation out_element_op)
    {
        return Argument{input,
                        weight,
                        output,
                        conv_filter_strides,
                        conv_filter_dilations,
                        input_left_pads,
                        input_right_pads,
                        in_element_op,
                        wei_element_op,
                        out_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceConvFwd"
            << std::endl;
        // clang-format on

        return str.str();
    }
};

} // namespace host
} // namespace tensor_operation
} // namespace ck
#endif