#ifndef __UNET_H__ #define __UNET_H__ #include #include #include #include using namespace std; using namespace cv; using namespace migraphx; namespace migraphxSamples { class Unet { public: Unet(); ~Unet(); ErrorCode Initialize(InitializationParameterOfSegmentation initParamOfSegmentationUnet); ErrorCode Segmentation(const cv::Mat &srcImage, cv::Mat &maskImage); private: ErrorCode DoCommonInitialization(InitializationParameterOfSegmentation initParamOfSegmentationUnet); float Sigmoid(float x); private: FILE *logFile; cv::FileStorage configurationFile; InitializationParameterOfSegmentation initializationParameter; migraphx::program net; cv::Size inputSize; std::string inputName; migraphx::shape inputShape; float scale; }; } #endif