#ifndef DEVICE_CONV_HPP #define DEVICE_CONV_HPP #include #include "device_base.hpp" namespace ck { namespace tensor_operation { namespace device { struct DeviceConvFwd : public BaseOperator { virtual std::unique_ptr MakeArgumentPointer(const void* p_in, const void* p_wei, void* p_out, ck::index_t N, ck::index_t K, ck::index_t C, std::vector input_spatial_lengths, std::vector filter_spatial_lengths, std::vector output_spatial_lengths, std::vector conv_filter_strides, std::vector conv_filter_dilations, std::vector input_left_pads, std::vector input_right_pads) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; }; struct DeviceConvBwd : public BaseOperator { virtual std::unique_ptr MakeArgumentPointer(void* p_in, const void* p_wei, const void* p_out, ck::index_t N, ck::index_t K, ck::index_t C, std::vector input_spatial_lengths, std::vector filter_spatial_lengths, std::vector output_spatial_lengths, std::vector conv_filter_strides, std::vector conv_filter_dilations, std::vector input_left_pads, std::vector input_right_pads) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; }; struct DeviceConvWrw : public BaseOperator { virtual std::unique_ptr MakeArgumentPointer(const void* p_in, void* p_wei, const void* p_out, ck::index_t N, ck::index_t K, ck::index_t C, std::vector input_spatial_lengths, std::vector filter_spatial_lengths, std::vector output_spatial_lengths, std::vector conv_filter_strides, std::vector conv_filter_dilations, std::vector input_left_pads, std::vector input_right_pads) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; }; using DeviceConvFwdPtr = std::unique_ptr; using DeviceConvBwdPtr = std::unique_ptr; using DeviceConvWrwPtr = std::unique_ptr; } // namespace device } // namespace tensor_operation } // namespace ck #endif