#include <OcrDB.h>
#include <migraphx/onnx.hpp>
#include <migraphx/gpu/target.hpp>
#include <Filesystem.h>
#include <SimpleLog.h>

using namespace cv::dnn;

namespace migraphxSamples
{

/// Default constructor. It does no work of its own; the network and the
/// detector parameters are set up by Initialize().
DB::DB() = default;

/// Destructor: releases the configuration file storage that Initialize()
/// opens.
DB::~DB()
{
    configurationFile.release();
}

/// Reads the detector configuration, loads and compiles the ONNX model for
/// the GPU target, and runs one warm-up inference.
///
/// @param InitializationParameterOfDB  Supplies configFilePath, the path of
///        the configuration file containing the "OcrDB" node.
/// @return SUCCESS when the model is ready; CONFIG_FILE_NOT_EXIST,
///         FAIL_TO_OPEN_CONFIG_FILE or MODEL_NOT_EXIST when the matching
///         check fails.
ErrorCode DB::Initialize(InitializationParameterOfDB InitializationParameterOfDB)
{
    // Read the configuration file
    const std::string cfgPath = InitializationParameterOfDB.configFilePath;
    if(!Exists(cfgPath))
    {
        LOG_ERROR(stdout, "no configuration file!\n");
        return CONFIG_FILE_NOT_EXIST;
    }
    if(!configurationFile.open(cfgPath, cv::FileStorage::READ))
    {
        LOG_ERROR(stdout, "fail to open configuration file\n");
        return FAIL_TO_OPEN_CONFIG_FILE;
    }
    LOG_INFO(stdout, "succeed to open configuration file\n");

    // Fetch the parameters stored under the "OcrDB" node
    cv::FileNode dbNode = configurationFile["OcrDB"];
    const std::string modelPath = static_cast<string>(dbNode["ModelPath"]);
    dbParameter.BinaryThreshold = static_cast<float>(dbNode["BinaryThreshold"]);
    dbParameter.BoxThreshold = static_cast<float>(dbNode["BoxThreshold"]);
    dbParameter.UnclipRatio = static_cast<float>(dbNode["UnclipRatio"]);
    dbParameter.LimitSideLen = static_cast<int>(dbNode["LimitSideLen"]);
    dbParameter.ScoreMode = static_cast<string>(dbNode["ScoreMode"]);

    // Load the model
    if(!Exists(modelPath))
    {
        LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str());
        return MODEL_NOT_EXIST;
    }
    migraphx::onnx_options parseOptions;
    parseOptions.map_input_dims["x"]={1,3,2496,2496}; // maximum input shape
    net = migraphx::parse_onnx(modelPath, parseOptions);
    LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str());

    // Print the model's input and output nodes
    std::cout<<"DB_inputs:"<<std::endl;
    std::unordered_map<std::string, migraphx::shape> inputs=net.get_inputs();
    for(const auto &entry:inputs)
    {
        std::cout<<entry.first<<":"<<entry.second<<std::endl;
    }
    std::cout<<"DB_outputs:"<<std::endl;
    std::unordered_map<std::string, migraphx::shape> outputs=net.get_outputs();
    for(const auto &entry:outputs)
    {
        std::cout<<entry.first<<":"<<entry.second<<std::endl;
    }

    // Remember the first input; dimensions 2 and 3 are used as height and width
    inputName=inputs.begin()->first;
    inputShape=inputs.begin()->second;
    const int maxHeight=static_cast<int>(inputShape.lens()[2]);
    const int maxWidth=static_cast<int>(inputShape.lens()[3]);
    inputSize=cv::Size(maxWidth,maxHeight);

    // Compile the model for the GPU target
    migraphx::target gpuTarget = migraphx::gpu::target{};
    migraphx::compile_options compileOptions;
    compileOptions.device_id=0;                   // GPU device to use; device 0 here
    compileOptions.offload_copy=true;
    net.compile(gpuTarget,compileOptions);
    LOG_INFO(stdout,"succeed to compile model: %s\n",GetFileName(modelPath).c_str());

    // Warm up with one inference on an argument of the maximum shape
    std::unordered_map<std::string, migraphx::argument> warmupInput;
    warmupInput[inputName]=migraphx::argument{inputShape};
    net.eval(warmupInput);

    // Log the resulting configuration
    LOG_INFO(stdout,"InputMaxSize:%dx%d\n",inputSize.width,inputSize.height);
    LOG_INFO(stdout,"InputName:%s\n",inputName.c_str());

    return SUCCESS;
}

/// Detects text regions in an image and appends one cropped image per region.
///
/// The image is resized so that its longer side does not exceed
/// dbParameter.LimitSideLen (each side is then rounded to a multiple of 32,
/// with a minimum of 32), normalised and passed to the network. The first
/// network output is read as a single probability map, thresholded, turned
/// into boxes by DBPostProcessor, and each box is cropped from the original
/// image.
///
/// @param img      Input image; must be non-empty and of type CV_8UC3.
/// @param imgList  Receives the crops, appended in the order produced by
///                 Utility::sorted_boxes.
/// @return IMAGE_ERROR when img is empty or not CV_8UC3, otherwise SUCCESS.
ErrorCode DB::Infer(const cv::Mat &img, std::vector<cv::Mat> &imgList)
{
    if(img.empty()||img.type()!=CV_8UC3)
    {
        LOG_ERROR(stdout, "image error!\n");
        return IMAGE_ERROR;
    }

    // Shrink (never enlarge) so that the longer side fits LimitSideLen.
    const int w = img.cols;
    const int h = img.rows;
    const int maxWH = std::max(h, w);
    float ratio = 1.f;
    if (maxWH > dbParameter.LimitSideLen)
    {
        ratio = float(dbParameter.LimitSideLen) / float(maxWH);
    }

    // Round both sides to a multiple of 32, never going below 32.
    int resizeH = int(float(h) * ratio);
    int resizeW = int(float(w) * ratio);
    resizeH = std::max(int(round(float(resizeH) / 32) * 32), 32);
    resizeW = std::max(int(round(float(resizeW) / 32) * 32), 32);

    // cv::resize only reads its source, so the input needs no deep copy.
    cv::Mat resizeImg;
    cv::resize(img, resizeImg, cv::Size(resizeW, resizeH));

    // Scale factors used later to map boxes back onto the original image.
    const float ratioH = float(resizeH) / float(h);
    const float ratioW = float(resizeW) / float(w);

    // Scale to [0,1], then apply (x - mean) * scale per channel, in the
    // channel order returned by cv::split.
    resizeImg.convertTo(resizeImg, CV_32FC3, 1.0/255.0);
    std::vector<cv::Mat> bgrChannels(3);
    cv::split(resizeImg, bgrChannels);

    const std::vector<float> mean = {0.485f, 0.456f, 0.406f};
    const std::vector<float> scale = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
    for (std::size_t i = 0; i < bgrChannels.size(); i++)
    {
        bgrChannels[i].convertTo(bgrChannels[i], CV_32FC1, 1.0 * scale[i],
                              (0.0 - mean[i]) * scale[i]);
    }
    cv::merge(bgrChannels, resizeImg);

    cv::Mat inputBlob = cv::dnn::blobFromImage(resizeImg);
    // Explicit casts: int -> size_t inside a braced list is a narrowing conversion.
    const std::vector<std::size_t> inputShapeOfInfer={
        1, 3,
        static_cast<std::size_t>(resizeImg.rows),
        static_cast<std::size_t>(resizeImg.cols)};

    // Build the input; the argument refers to inputBlob's buffer, so
    // inputBlob has to stay alive until eval() has returned.
    std::unordered_map<std::string, migraphx::argument> inputData;
    inputData[inputName]= migraphx::argument{migraphx::shape(inputShape.type(),inputShapeOfInfer),
                                             reinterpret_cast<float*>(inputBlob.data)};

    // Run inference and take the first output.
    std::vector<migraphx::argument> inferenceResults = net.eval(inputData);
    migraphx::argument result = inferenceResults[0];

    // The output is indexed as 4-D; dimensions 2 and 3 give the map size.
    // NOTE(review): assumes the output holds float data with at least 4 dims - confirm.
    const migraphx::shape outputShape = result.get_shape();
    const int n2 = static_cast<int>(outputShape.lens()[2]);
    const int n3 = static_cast<int>(outputShape.lens()[3]);
    const int n = n2 * n3;

    // Copy exactly one n2 x n3 plane. Copying outputShape.elements() floats,
    // as before, overruns the buffer whenever the leading dimensions are not 1.
    std::vector<float> pred(n);
    memcpy(pred.data(), result.data(), sizeof(float) * pred.size());

    // 8-bit copy of the probability map for thresholding.
    std::vector<unsigned char> cbuf(n);
    for (int i = 0; i < n; i++)
    {
        cbuf[i] = (unsigned char)(pred[i] * 255);
    }

    // Both Mats wrap the vectors above without copying.
    cv::Mat cbufMap(n2, n3, CV_8UC1, cbuf.data());
    cv::Mat predMap(n2, n3, CV_32F, pred.data());
    const double threshold = dbParameter.BinaryThreshold * 255;
    const double maxvalue = 255;
    cv::Mat bitMap;
    cv::threshold(cbufMap, bitMap, threshold, maxvalue, cv::THRESH_BINARY);

    // Extract boxes from the binary map and map them back to the original image.
    DBPostProcessor postProcessor;
    std::vector<std::vector<std::vector<int>>> boxes =
        postProcessor.BoxesFromBitmap(predMap, bitMap, dbParameter.BoxThreshold,
                                      dbParameter.UnclipRatio, dbParameter.ScoreMode);
    boxes = postProcessor.FilterTagDetRes(boxes, ratioH, ratioW, img);

    std::vector<migraphxSamples::OCRPredictResult> ocrResults;
    for (const auto &box : boxes)
    {
        OCRPredictResult res;
        res.box = box;
        ocrResults.push_back(res);
    }
    Utility::sorted_boxes(ocrResults);

    // Crop every detected region from the original image.
    for (std::size_t j = 0; j < ocrResults.size(); j++)
    {
        imgList.push_back(Utility::GetRotateCropImage(img, ocrResults[j].box));
    }

    // The original fell off the end of this non-void function here.
    return SUCCESS;
}

/// Computes the offset distance used to expand a quadrilateral.
///
/// Despite its name, the value written to `distance` is
/// area * unclip_ratio / perimeter, where the area comes from the
/// cross-product sum over the four corners and the perimeter is the sum of
/// the four edge lengths.
///
/// @param box           Four corner points, each stored as {x, y}.
/// @param unclip_ratio  Multiplier applied to the area.
/// @param distance      Output: the resulting offset distance.
void DBPostProcessor::GetContourArea(const std::vector<std::vector<float>> &box,
                                     float unclip_ratio, float &distance) {
  const int kCorners = 4;
  float crossSum = 0.0f;   // twice the signed area
  float perimeter = 0.0f;
  for (int cur = 0; cur < kCorners; cur++) {
    const int nxt = (cur + 1) % kCorners;  // wraps back to the first corner
    const float x0 = box[cur][0];
    const float y0 = box[cur][1];
    const float x1 = box[nxt][0];
    const float y1 = box[nxt][1];

    crossSum += x0 * y1 - y0 * x1;
    perimeter += sqrtf((x0 - x1) * (x0 - x1) + (y0 - y1) * (y0 - y1));
  }
  const float area = fabs(float(crossSum / 2.0));

  distance = area * unclip_ratio / perimeter;
}

/// Expands a quadrilateral outwards and returns the minimum-area rotated
/// rectangle around the expanded outline.
///
/// The offset distance comes from GetContourArea(); the expansion itself is
/// done by ClipperLib::ClipperOffset with round joins on a closed polygon.
///
/// @param box           Four corner points, each stored as {x, y}; the
///                      coordinates are truncated to integers for Clipper.
/// @param unclip_ratio  Expansion ratio forwarded to GetContourArea().
/// @return The bounding rotated rectangle of all offset points, or a 1x1
///         rectangle at the origin when the offset produced no points.
cv::RotatedRect DBPostProcessor::UnClip(std::vector<std::vector<float>> box,
                                        const float &unclip_ratio) {
  float distance = 1.0;
  GetContourArea(box, unclip_ratio, distance);

  ClipperLib::Path outline;
  outline << ClipperLib::IntPoint(int(box[0][0]), int(box[0][1]))
          << ClipperLib::IntPoint(int(box[1][0]), int(box[1][1]))
          << ClipperLib::IntPoint(int(box[2][0]), int(box[2][1]))
          << ClipperLib::IntPoint(int(box[3][0]), int(box[3][1]));

  ClipperLib::ClipperOffset offset;
  offset.AddPath(outline, ClipperLib::jtRound, ClipperLib::etClosedPolygon);

  ClipperLib::Paths soln;
  offset.Execute(soln, distance);

  // Gather the points of every resulting path. Each path is walked over its
  // own length; the previous code bounded every path by the length of the
  // last one, which read out of range when the paths differed in size.
  std::vector<cv::Point2f> points;
  for (const auto &path : soln) {
    for (const auto &pt : path) {
      points.emplace_back(pt.X, pt.Y);
    }
  }

  if (points.empty()) {
    // Sentinel rectangle for "no offset result".
    return cv::RotatedRect(cv::Point2f(0, 0), cv::Size2f(1, 1), 0);
  }
  return cv::minAreaRect(points);
}

/// Copies a float matrix into a freshly allocated array of row arrays.
///
/// Elements are read with mat.at<float>(row, col). The outer array and each
/// row are allocated with new[] and nothing here frees them, so the caller
/// must delete[] every row and then the outer array.
///
/// @param mat  Source matrix, read as float elements.
/// @return Pointer to mat.rows row pointers, each holding mat.cols floats.
float **DBPostProcessor::Mat2Vec(cv::Mat mat) {
  float **rows = new float *[mat.rows];
  for (int r = 0; r < mat.rows; ++r) {
    rows[r] = new float[mat.cols];
  }

  for (int r = 0; r < mat.rows; ++r) {
    float *dst = rows[r];
    for (int c = 0; c < mat.cols; ++c) {
      dst[c] = mat.at<float>(r, c);
    }
  }
  return rows;
}

/// Reorders four {x, y} points into a fixed corner order.
///
/// The points are sorted by x. Of the two with the smallest x, the one with
/// the smaller y comes first in the result and the other comes last; the
/// two with the largest x fill the middle, smaller y first.
///
/// @param pts  Four points, each stored as {x, y}.
/// @return The same four points in the order described above.
std::vector<std::vector<int>>
DBPostProcessor::OrderPointsClockwise(std::vector<std::vector<int>> pts) {
  // pts is already a private copy, so it can be sorted in place.
  std::sort(pts.begin(), pts.end(), XsortInt);

  std::vector<int> leftLowY = pts[0];
  std::vector<int> leftHighY = pts[1];
  if (leftLowY[1] > leftHighY[1]) {
    std::swap(leftLowY, leftHighY);
  }

  std::vector<int> rightLowY = pts[2];
  std::vector<int> rightHighY = pts[3];
  if (rightLowY[1] > rightHighY[1]) {
    std::swap(rightLowY, rightHighY);
  }

  return {leftLowY, rightLowY, rightHighY, leftHighY};
}

/// Copies a float matrix into a vector of row vectors.
///
/// @param mat  Source matrix, read with mat.at<float>(row, col).
/// @return One inner vector per matrix row, each with mat.cols values.
std::vector<std::vector<float>> DBPostProcessor::Mat2Vector(cv::Mat mat) {
  std::vector<std::vector<float>> rows;

  for (int r = 0; r < mat.rows; ++r) {
    std::vector<float> row;
    for (int c = 0; c < mat.cols; ++c) {
      row.push_back(mat.at<float>(r, c));
    }
    rows.push_back(row);
  }
  return rows;
}

/// Sort predicate ordering float points by their first coordinate.
///
/// @return true when a[0] is strictly less than b[0]; false otherwise,
///         including when the two are equal.
bool DBPostProcessor::XsortFp32(std::vector<float> a, std::vector<float> b) {
  return a[0] < b[0];
}

/// Sort predicate ordering integer points by their first coordinate.
///
/// @return true when a[0] is strictly less than b[0]; false otherwise,
///         including when the two are equal.
bool DBPostProcessor::XsortInt(std::vector<int> a, std::vector<int> b) {
  return a[0] < b[0];
}

/// Returns the four corners of a rotated rectangle in a fixed order and
/// reports the rectangle's longer side.
///
/// The corners from cv::boxPoints are sorted by x. The two with the
/// smallest x supply the first and last output points (smaller y first);
/// the two with the largest x supply the second and third (smaller y
/// first). On equal y, the point later in the sorted order is taken first.
///
/// @param box   Rectangle whose corners are extracted.
/// @param ssid  Output: max(box.size.width, box.size.height).
/// @return Four {x, y} points in the order described above.
std::vector<std::vector<float>>
DBPostProcessor::GetMiniBoxes(cv::RotatedRect box, float &ssid) {
  ssid = std::max(box.size.width, box.size.height);

  cv::Mat corners;
  cv::boxPoints(box, corners);

  auto sorted = Mat2Vector(corners);
  std::sort(sorted.begin(), sorted.end(), XsortFp32);

  // Left pair: sorted[0] and sorted[1]; right pair: sorted[2] and sorted[3].
  const bool leftSwapped = sorted[1][1] <= sorted[0][1];
  const bool rightSwapped = sorted[3][1] <= sorted[2][1];

  const std::vector<float> first = leftSwapped ? sorted[1] : sorted[0];
  const std::vector<float> fourth = leftSwapped ? sorted[0] : sorted[1];
  const std::vector<float> second = rightSwapped ? sorted[3] : sorted[2];
  const std::vector<float> third = rightSwapped ? sorted[2] : sorted[3];

  sorted[0] = first;
  sorted[1] = second;
  sorted[2] = third;
  sorted[3] = fourth;
  return sorted;
}

/// Scores a contour as the mean of `pred` over the filled polygon.
///
/// The contour's bounding box is clamped to the bounds of `pred`, the
/// polygon is rasterised into a mask of that size, and the mean of `pred`
/// under the mask is returned.
///
/// @param contour  Polygon vertices in the coordinates of `pred`.
/// @param pred     Prediction map. NOTE(review): read as a single-channel
///                 float map by the caller in this file - confirm.
/// @return Mean prediction value inside the polygon, or 0 for an empty
///         contour.
float DBPostProcessor::PolygonScoreAcc(std::vector<cv::Point> contour,
                                       cv::Mat pred) {
  // An empty contour has no extent; the previous code dereferenced end().
  if (contour.empty()) {
    return 0.0f;
  }

  const int width = pred.cols;
  const int height = pred.rows;

  const auto byX = [](const cv::Point &a, const cv::Point &b) {
    return a.x < b.x;
  };
  const auto byY = [](const cv::Point &a, const cv::Point &b) {
    return a.y < b.y;
  };
  const auto xRange = std::minmax_element(contour.begin(), contour.end(), byX);
  const auto yRange = std::minmax_element(contour.begin(), contour.end(), byY);

  // Coordinates are integers already, so no floor/ceil is needed before
  // clamping the bounding box to the map.
  const int xmin = clamp(xRange.first->x, 0, width - 1);
  const int xmax = clamp(xRange.second->x, 0, width - 1);
  const int ymin = clamp(yRange.first->y, 0, height - 1);
  const int ymax = clamp(yRange.second->y, 0, height - 1);

  cv::Mat mask = cv::Mat::zeros(ymax - ymin + 1, xmax - xmin + 1, CV_8UC1);

  // Vertices relative to the bounding box. A vector replaces the manual
  // new[]/delete[] pair, so nothing leaks if an OpenCV call below throws.
  std::vector<cv::Point> shifted;
  shifted.reserve(contour.size());
  for (const cv::Point &pt : contour) {
    shifted.emplace_back(pt.x - xmin, pt.y - ymin);
  }
  const cv::Point *ppt[1] = {shifted.data()};
  int npt[] = {int(shifted.size())};
  cv::fillPoly(mask, ppt, npt, 1, cv::Scalar(1));

  cv::Mat croppedImg;
  pred(cv::Rect(xmin, ymin, xmax - xmin + 1, ymax - ymin + 1))
      .copyTo(croppedImg);
  const float score = cv::mean(croppedImg, mask)[0];
  return score;
}

float DBPostProcessor::BoxScoreFast(std::vector<std::vector<float>> box_array,
                                    cv::Mat pred) {
  auto array = box_array;
  int width = pred.cols;
  int height = pred.rows;

  float box_x[4] = {array[0][0], array[1][0], array[2][0], array[3][0]};
  float box_y[4] = {array[0][1], array[1][1], array[2][1], array[3][1]};

  int xmin = clamp(int(std::floor(*(std::min_element(box_x, box_x + 4)))), 0,
                   width - 1);
  int xmax = clamp(int(std::ceil(*(std::max_element(box_x, box_x + 4)))), 0,
                   width - 1);
  int ymin = clamp(int(std::floor(*(std::min_element(box_y, box_y + 4)))), 0,
                   height - 1);
  int ymax = clamp(int(std::ceil(*(std::max_element(box_y, box_y + 4)))), 0,
                   height - 1);

  cv::Mat mask;
  mask = cv::Mat::zeros(ymax - ymin + 1, xmax - xmin + 1, CV_8UC1);

  cv::Point root_point[4];
  root_point[0] = cv::Point(int(array[0][0]) - xmin, int(array[0][1]) - ymin);
  root_point[1] = cv::Point(int(array[1][0]) - xmin, int(array[1][1]) - ymin);
  root_point[2] = cv::Point(int(array[2][0]) - xmin, int(array[2][1]) - ymin);
  root_point[3] = cv::Point(int(array[3][0]) - xmin, int(array[3][1]) - ymin);
  const cv::Point *ppt[1] = {root_point};
  int npt[] = {4};
  cv::fillPoly(mask, ppt, npt, 1, cv::Scalar(1));

  cv::Mat croppedImg;
  pred(cv::Rect(xmin, ymin, xmax - xmin + 1, ymax - ymin + 1))
      .copyTo(croppedImg);

  auto score = cv::mean(croppedImg, mask)[0];
  return score;
}

/// Extracts text boxes from a binarised map.
///
/// Contours are found in `bitmap` and at most 2000 of them are examined.
/// A contour is dropped when it has 2 points or fewer, when the longer side
/// of its minimum-area rectangle is below 3, when its score is below
/// `box_thresh`, when the expanded rectangle is no larger than the 1x1
/// rectangle UnClip() returns for an empty result, or when the expanded
/// rectangle's longer side is below 5. Surviving boxes are scaled from
/// `bitmap` size to `pred` size, rounded and clamped to `pred`'s bounds.
///
/// @param pred                 Prediction map used for scoring.
/// @param bitmap               Binary map whose contours are extracted.
/// @param box_thresh           Minimum score for a box to be kept.
/// @param det_db_unclip_ratio  Expansion ratio passed to UnClip().
/// @param det_db_score_mode    "slow" scores with PolygonScoreAcc(); any
///                             other value scores with BoxScoreFast().
/// @return One entry per kept box, each holding four {x, y} integer points.
std::vector<std::vector<std::vector<int>>> DBPostProcessor::BoxesFromBitmap(
    const cv::Mat pred, const cv::Mat bitmap, const float &box_thresh,
    const float &det_db_unclip_ratio, const std::string &det_db_score_mode) {
  const int kMinSide = 3;
  const int kMaxCandidates = 2000;

  const float srcW = float(bitmap.cols);
  const float srcH = float(bitmap.rows);
  const float dstW = float(pred.cols);
  const float dstH = float(pred.rows);

  std::vector<std::vector<cv::Point>> contours;
  std::vector<cv::Vec4i> hierarchy;
  cv::findContours(bitmap, contours, hierarchy, cv::RETR_LIST,
                   cv::CHAIN_APPROX_SIMPLE);

  // Examine at most kMaxCandidates contours.
  int limit = kMaxCandidates;
  if (contours.size() < static_cast<std::size_t>(kMaxCandidates)) {
    limit = int(contours.size());
  }

  const bool polygonScore = (det_db_score_mode == "slow");
  std::vector<std::vector<std::vector<int>>> result;

  for (int idx = 0; idx < limit; idx++) {
    const std::vector<cv::Point> &contour = contours[idx];
    if (contour.size() <= 2) {
      continue;
    }

    // Minimum-area rectangle of the contour and its longer side.
    float longSide = 0.0f;
    const cv::RotatedRect rect = cv::minAreaRect(contour);
    const auto quad = GetMiniBoxes(rect, longSide);
    if (longSide < kMinSide) {
      continue;
    }

    // "slow" averages over the exact polygon, otherwise over the rectangle.
    float score;
    if (polygonScore) {
      score = PolygonScoreAcc(contour, pred);
    } else {
      score = BoxScoreFast(quad, pred);
    }
    if (score < box_thresh) {
      continue;
    }

    // Expand the rectangle; a result of about 1x1 is UnClip's empty sentinel.
    const cv::RotatedRect expanded = UnClip(quad, det_db_unclip_ratio);
    if (expanded.size.height < 1.001 && expanded.size.width < 1.001) {
      continue;
    }

    const auto expandedQuad = GetMiniBoxes(expanded, longSide);
    if (longSide < kMinSide + 2) {
      continue;
    }

    // Scale from bitmap size to pred size, round, and clamp to pred's bounds.
    std::vector<std::vector<int>> scaled;
    for (int k = 0; k < 4; k++) {
      const float sx = roundf(expandedQuad[k][0] / srcW * dstW);
      const float sy = roundf(expandedQuad[k][1] / srcH * dstH);
      std::vector<int> pt{int(clampf(sx, 0, dstW)), int(clampf(sy, 0, dstH))};
      scaled.push_back(pt);
    }
    result.push_back(scaled);
  }
  return result;
}

/// Maps detected boxes back to the source image and drops tiny ones.
///
/// Each box is reordered with OrderPointsClockwise(), its coordinates are
/// divided by the resize ratios (truncating to int) and clamped to the
/// image bounds. A box is kept only when both the edge from point 0 to
/// point 1 and the edge from point 0 to point 3 are longer than 4 after
/// truncation to int.
///
/// @param boxes    Boxes in resized-image coordinates, four {x, y} points each.
/// @param ratio_h  Vertical resize ratio (resized height / source height).
/// @param ratio_w  Horizontal resize ratio (resized width / source width).
/// @param srcimg   Source image; only its size is used.
/// @return The surviving boxes, in their original relative order.
std::vector<std::vector<std::vector<int>>> DBPostProcessor::FilterTagDetRes(
    std::vector<std::vector<std::vector<int>>> boxes, float ratio_h,
    float ratio_w, cv::Mat srcimg) {
  const int oriimg_h = srcimg.rows;
  const int oriimg_w = srcimg.cols;

  std::vector<std::vector<std::vector<int>>> root_points;
  for (auto &box : boxes) {
    // The steps below index points 0..3, so incomplete boxes are dropped
    // instead of being read out of range.
    if (box.size() < 4) {
      continue;
    }
    box = OrderPointsClockwise(box);

    // Iterate over this box's own points; the previous bound was taken
    // from boxes[0] regardless of which box was being processed.
    for (auto &pt : box) {
      const int x = static_cast<int>(pt[0] / ratio_w);
      const int y = static_cast<int>(pt[1] / ratio_h);
      pt[0] = std::min(std::max(x, 0), oriimg_w - 1);
      pt[1] = std::min(std::max(y, 0), oriimg_h - 1);
    }

    // Edge lengths from point 0 to points 1 and 3; plain multiplication
    // replaces pow(x, 2).
    const double wx = box[0][0] - box[1][0];
    const double wy = box[0][1] - box[1][1];
    const double hx = box[0][0] - box[3][0];
    const double hy = box[0][1] - box[3][1];
    const int rect_width = static_cast<int>(std::sqrt(wx * wx + wy * wy));
    const int rect_height = static_cast<int>(std::sqrt(hx * hx + hy * hy));
    if (rect_width <= 4 || rect_height <= 4) {
      continue;
    }
    root_points.push_back(box);
  }
  return root_points;
}

}