#include <Sample.h>
#include <opencv2/dnn.hpp>
#include <SimpleLog.h>
#include <Filesystem.h>
#include <Unet.h>

using namespace std;
using namespace cv;
using namespace cv::dnn;
using namespace migraphx;
using namespace migraphxSamples;

void Sample_Unet()
{
    // 加载Unet模型
    Unet unet;
    InitializationParameterOfSegmentation initParamOfSegmentationUnet;
    initParamOfSegmentationUnet.parentPath = "";
    initParamOfSegmentationUnet.configFilePath = CONFIG_FILE;
    initParamOfSegmentationUnet.logName = "";
    ErrorCode errorCode=unet.Initialize(initParamOfSegmentationUnet);
    if(errorCode!=SUCCESS)
    {
        LOG_ERROR(stdout, "fail to initialize Unet!\n");
        exit(-1);
    }
    LOG_INFO(stdout, "succeed to initialize Unet\n");

    // 读取图像
    cv::Mat srcImage =cv::imread("../Resource/Images/car1.jpeg", 1);
   
    // 推理
    cv::Mat maskImage;   
    double time1 = getTickCount();
    unet.Segmentation(srcImage, maskImage);
    double time2 = getTickCount();
    double elapsedTime = (time2 - time1) * 1000 / getTickFrequency();
    LOG_INFO(stdout, "inference time:%f ms\n", elapsedTime);
    LOG_INFO(stdout,"========== Segmentation Results ==========\n");
    LOG_INFO(stdout,"Segmentation results have been saved to ./Result.jpg\n");
    cv::imwrite("./Result.jpg", maskImage);
}