#ifndef DEVICE_REDUCE_HPP #define DEVICE_REDUCE_HPP #include #include #include #include "common_header.hpp" #include "device_base.hpp" #include "reduction_enums.hpp" namespace ck { namespace tensor_operation { namespace device { template struct DeviceReduce : public BaseOperator { virtual size_t GetWorkspaceSizeInBytes(const std::vector& inLengths) { (void)inLengths; return (0); }; virtual bool HasFurtherCall() { return (false); }; virtual std::vector GetWorkspace2dLengths(const BaseArgument* argPtr) { (void)argPtr; return (std::vector{0, 0}); }; virtual std::unique_ptr MakeArgumentPointer(const std::vector& inLengths, const std::vector& inStrides, const std::vector& outLengths, const std::vector& outStrides, float alpha, float beta, const void* in_dev, void* out_dev, void* out_indices_dev, void* workspace_dev, const InElementwiseOperation& inElementwiseOp, const AccElementwiseOperation& accElementwiseOp) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; }; template using DeviceReducePtr = std::unique_ptr>; } // namespace device } // namespace tensor_operation } // namespace ck #endif