#ifndef DEVICE_CONV_HPP #define DEVICE_CONV_HPP #include #include "device_base.hpp" namespace ck { namespace tensor_operation { namespace device { template struct DeviceConvFwd : public BaseOperator { virtual std::unique_ptr MakeArgumentPointer(const void* p_in, const void* p_wei, void* p_out, ck::index_t N, ck::index_t K, ck::index_t C, std::vector input_spatial_lengths, std::vector filter_spatial_lengths, std::vector output_spatial_lengths, std::vector conv_filter_strides, std::vector conv_filter_dilations, std::vector input_left_pads, std::vector input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; }; template struct DeviceConvBwd : public BaseOperator { virtual std::unique_ptr MakeArgumentPointer(void* p_in, const void* p_wei, const void* p_out, ck::index_t N, ck::index_t K, ck::index_t C, std::vector input_spatial_lengths, std::vector filter_spatial_lengths, std::vector output_spatial_lengths, std::vector conv_filter_strides, std::vector conv_filter_dilations, std::vector input_left_pads, std::vector input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; }; template struct DeviceConvWrw : public BaseOperator { virtual std::unique_ptr MakeArgumentPointer(const void* p_in, void* p_wei, const void* p_out, ck::index_t N, ck::index_t K, ck::index_t C, std::vector input_spatial_lengths, std::vector filter_spatial_lengths, std::vector output_spatial_lengths, std::vector conv_filter_strides, std::vector conv_filter_dilations, std::vector input_left_pads, std::vector input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; }; template using DeviceConvFwdPtr = std::unique_ptr< DeviceConvFwd>; template using DeviceConvBwdPtr = std::unique_ptr< DeviceConvBwd>; template using DeviceConvWrwPtr = std::unique_ptr< DeviceConvWrw>; } // namespace device } // namespace tensor_operation } // namespace ck #endif