#include "ocr_engine.hpp"

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

using namespace migraphxSamples;

/// Bound `n` to the closed interval [lower, upper].
/// The upper bound is applied first, then the lower bound, so when
/// lower > upper the result is `lower`.
template <typename T>
T clip(const T &n, const T &lower, const T &upper)
{
    const T capped = (upper < n) ? upper : n;
    return (lower < capped) ? capped : lower;
}

/// Offset from `first` of the first largest element in [first, last).
/// Returns 0 for an empty range.
template<class ForwardIterator>
inline size_t argmax(ForwardIterator first, ForwardIterator last)
{
    const ForwardIterator best = std::max_element(first, last);
    return static_cast<size_t>(std::distance(first, best));
}

/// Logistic function 1 / (1 + e^-x), narrowed to float on return.
static float sigmoid(float x)
{
    const auto denominator = 1 + exp(-x);
    return 1 / denominator;
}

/// Restrict `x` to [min, max]. The upper bound is tested first, so when
/// min > max any x above max yields max.
template <class T> inline T clamp(T x, T min, T max) {
    if (!(x > max)) {
        return (x < min) ? min : x;
    }
    return max;
}

/// Float-only variant of clamp(): the upper bound wins when min > max.
inline float clampf(float x, float min, float max) {
    return (x > max) ? max : ((x < min) ? min : x);
}

inline int _max(int a, int b) { return (b > a) ? b : a; }  // larger of two ints

inline int _min(int a, int b) { return (a < b) ? a : b; }  // smaller of two ints


// Sort predicate for 2-D float points: ascending by the first coordinate (x) only.
bool XsortFp32(std::vector<float> a, std::vector<float> b) {
    return a[0] < b[0];
}
    
    // Sort predicate for 2-D int points: ascending by the first coordinate (x) only.
    bool XsortInt(std::vector<int> a, std::vector<int> b) {
        return a[0] < b[0];
    }
namespace ppocr
{

    /**
     * Load, optionally fp16-quantize, and compile the text-detection model.
     *
     * @param det_model_path  path to the detection ONNX model; the process exits
     *                        with a failure status if it does not exist
     * @param precision_mode  "fp16" enables migraphx::quantize_fp16; any other value
     *                        keeps the parsed precision
     * @param offload_copy    when false, device input/output buffers and a host output
     *                        buffer are allocated here and reused by forward()
     * @param segm_thres      stored in this->segm_thres
     * @param box_thresh      stored in this->box_thres
     */
    OcrDet::OcrDet(const std::string det_model_path,
            std::string precision_mode,
            bool offload_copy,
            float segm_thres,
            float box_thresh )
    {
        if(!Exists(det_model_path))
        {
            LOG_ERROR(stdout, "onnx file not exists!\n");
            // Non-zero status: a missing model is a failure, not a clean shutdown.
            exit(EXIT_FAILURE);
        }

        this->max_candidates = 1000;
        this->det_batch_size = 1;
        this->segm_thres = segm_thres;
        this->box_thres  = box_thresh;
        this->precision_mode = precision_mode;

        // Pin the model input "x" to a static 8x3x640x640 shape.
        // NOTE(review): only batch slot 0 is filled by preproc() and decoded by
        // postprocess(); a batch of 1 would avoid 7 images of wasted compute.
        migraphx::onnx_options onnx_options;
        onnx_options.map_input_dims["x"] = {8, 3, 640, 640};

        net = migraphx::parse_onnx(det_model_path,onnx_options);
        LOG_INFO(stdout, "Succeed to load model: %s\n", GetFileName(det_model_path).c_str());

        if(this->precision_mode.compare("fp16")==0)
        {
            LOG_INFO(stdout, "Set precison mode: %s\n",this->precision_mode.c_str());
            migraphx::quantize_fp16(net);
        }

        std::unordered_map<std::string, migraphx::shape> inputs  = net.get_inputs();
        std::unordered_map<std::string, migraphx::shape> outputs = net.get_outputs();
        this->input_name   = inputs.begin()->first;
        this->input_shape  = inputs.begin()->second;
        this->output_name  = outputs.begin()->first;
        this->output_shape = outputs.begin()->second;

        const int N      = this->input_shape.lens()[0];
        const int C      = this->input_shape.lens()[1];
        const int H      = this->input_shape.lens()[2];
        const int W      = this->input_shape.lens()[3];
        this->data_size  = N*C*H*W;

        // The host staging buffer must cover the WHOLE input tensor. forward()
        // reads input_shape.bytes() from it, either through hipMemcpy or through a
        // migraphx::argument view. Allocating a single image (C*H*W) made those
        // reads run past the end of the buffer whenever N > 1. calloc also
        // zero-fills the batch slots that preproc() never writes.
        data = (float*)calloc(static_cast<std::size_t>(N) * C * H * W, sizeof(float));
        if(data == nullptr)
        {
            LOG_ERROR(stdout, "failed to allocate the detection input buffer!\n");
            exit(EXIT_FAILURE);
        }

        net_input_width = W;
        net_input_height = H;
        net_input_channel = C;

        n_channel     =  this->output_shape.lens()[1];
        output_width  =  this->output_shape.lens()[3];
        output_height =  this->output_shape.lens()[2];
        feature_size  =  output_width*output_height;

        this->offload_copy = offload_copy;
        migraphx::compile_options options;
        options.device_id = 0; // default device cuda:0
        options.offload_copy = offload_copy;
        migraphx::target gpuTarget = migraphx::gpu::target{};
        net.compile(gpuTarget, options);
        if( this->offload_copy ==false )
        {
            // Copy mode: forward() moves data between these buffers explicitly.
            if(hipMalloc(&input_buffer_device, this->input_shape.bytes()) != hipSuccess)
            {
                LOG_ERROR(stdout, "hipMalloc failed for the detection input buffer!\n");
                exit(EXIT_FAILURE);
            }
            if(hipMalloc(&output_buffer_device, this->output_shape.bytes()) != hipSuccess)
            {
                LOG_ERROR(stdout, "hipMalloc failed for the detection output buffer!\n");
                exit(EXIT_FAILURE);
            }
            output_buffer_host   =  (void*)malloc(this->output_shape.bytes());

            dev_argument[input_name]  = migraphx::argument{input_shape, input_buffer_device};
            dev_argument[output_name] = migraphx::argument{output_shape, output_buffer_device};
        }
    }

    /**
     * Release the host input buffer and, in copy mode (offload_copy == false),
     * the device I/O buffers and host output buffer allocated by the constructor.
     */
    OcrDet::~OcrDet()
    {
        free(data);   // free(nullptr) is a no-op, so no guard is needed
        data = nullptr;

        if (offload_copy)
            return;   // the device/host I/O buffers only exist in copy mode

        if (input_buffer_device)  { hipFree(input_buffer_device); }
        if (output_buffer_device) { hipFree(output_buffer_device); }
        free(output_buffer_host);
    }

    /**
     * Resize `img` to the network input size and write it as normalized planar
     * (CHW) floats into batch slot 0 of `data`.
     *
     * Channel k of the output holds pixel channel k, normalized with
     * (v/255 - mean[k]) / std[k]. Single-channel (gray) and 4-channel images are
     * converted to 3 channels first, because reading them as cv::Vec3b would
     * access memory outside the image.
     *
     * @return img-to-network scale factors stored in a cv::Size. NOTE: cv::Size
     *         holds ints, so the ratios are truncated toward zero. forward()
     *         recomputes the float ratios itself and ignores this value.
     *         Returns cv::Size(1, 1) for an empty image without touching `data`.
     */
    cv::Size OcrDet::preproc(cv::Mat img,float* data)
    {
        float scale = 1.0/255.0;
        static const float s_mean[3] = {0.485, 0.456, 0.406};
        static const float s_stdv[3] = {0.229, 0.224, 0.225};
        if(img.empty())
        {
            std::cout<<"Source image is empty!\n";
            return cv::Size(1.0,1.0);
        }

        // Normalize the channel count so the Vec3b accesses below stay in bounds.
        cv::Mat src = img;
        if (img.type() == CV_8UC1)
            cv::cvtColor(img, src, cv::COLOR_GRAY2BGR);
        else if (img.type() == CV_8UC4)
            cv::cvtColor(img, src, cv::COLOR_BGRA2BGR);

        cv::Size scale_r;
        scale_r.width  = static_cast<int>(float(net_input_width)/float(img.cols));
        scale_r.height = static_cast<int>(float(net_input_height)/float(img.rows));

        cv::Mat res_img;
        cv::resize(src,res_img,cv::Size(net_input_width,net_input_height));

        const int plane = net_input_height * net_input_width;
        memset(data, 0, 3 * plane * sizeof(float));
        for (int i = 0; i < net_input_height; i++)
        {
            const cv::Vec3b *row = res_img.ptr<cv::Vec3b>(i);
            for (int j = 0; j < net_input_width; j++)
            {
                const int idx = i * net_input_width + j;
                data[idx]             = (float(row[j][0])*scale-s_mean[0])/s_stdv[0];
                data[idx + plane]     = (float(row[j][1])*scale-s_mean[1])/s_stdv[1];
                data[idx + 2 * plane] = (float(row[j][2])*scale-s_mean[2])/s_stdv[2];
            }
        }
        return scale_r;
    }
    
    /**
     * Corner points of a rotated rectangle in a fixed order.
     *
     * Corners are sorted by x. Of the two left-most, the one with the smaller y
     * comes first and the other last. Of the two right-most, the one with the
     * smaller y is second and the other third.
     *
     * @param box   rectangle to decompose
     * @param ssid  receives the longer side of `box`
     * @return four {x, y} float points
     */
    std::vector<std::vector<float>> OcrDet::get_mini_boxes(cv::RotatedRect box,float &ssid)
    {
        ssid = std::max(box.size.width, box.size.height);

        cv::Mat corners;
        cv::boxPoints(box, corners);
        std::vector<std::vector<float>> pts = Mat2Vector(corners);
        std::sort(pts.begin(), pts.end(), XsortFp32);

        // Ties (equal y) resolve toward the later point in x-sorted order.
        const bool left_swapped  = pts[1][1] <= pts[0][1];
        const bool right_swapped = pts[3][1] <= pts[2][1];

        std::vector<std::vector<float>> ordered{
            left_swapped  ? pts[1] : pts[0],
            right_swapped ? pts[3] : pts[2],
            right_swapped ? pts[2] : pts[3],
            left_swapped  ? pts[0] : pts[1]};
        return ordered;
    }

    /**
     * Extract text quadrilaterals from a binary segmentation bitmap.
     *
     * For each of the first 1000 contours of `bitmap`:
     *  - reject contours with at most 2 points or a min-area rect shorter than 3;
     *  - score the contour with polygon_score_acc or box_score_fast against `pred`,
     *    and reject it if the score is below `box_thresh`;
     *  - expand the quad with unClip(det_db_unclip_ratio), rejecting degenerate
     *    results and those whose longer side is below 5;
     *  - rescale the corners from bitmap size to `pred` size, round them and clamp
     *    them to [0, pred.cols] x [0, pred.rows].
     *
     * @return one 4x{x, y} int quad per accepted contour
     */
    std::vector<std::vector<std::vector<int>>> OcrDet::boxes_from_bitmap(
        const cv::Mat pred, const cv::Mat bitmap, const float &box_thresh,
        const float &det_db_unclip_ratio, const bool &use_polygon_score)
    {
        const int kMinSide = 3;
        const std::size_t kCandidateLimit = 1000;

        const float src_w = float(bitmap.cols);
        const float src_h = float(bitmap.rows);
        const float dst_w = float(pred.cols);
        const float dst_h = float(pred.rows);

        std::vector<std::vector<cv::Point>> contours;
        std::vector<cv::Vec4i> hierarchy;
        cv::findContours(bitmap, contours, hierarchy, cv::RETR_LIST,
                         cv::CHAIN_APPROX_SIMPLE);

        const std::size_t n_candidates = std::min(contours.size(), kCandidateLimit);
        std::vector<std::vector<std::vector<int>>> result;

        for (std::size_t k = 0; k < n_candidates; ++k)
        {
            const std::vector<cv::Point> &contour = contours[k];
            if (contour.size() <= 2)
                continue;

            float side = 0.f;
            std::vector<std::vector<float>> quad = get_mini_boxes(cv::minAreaRect(contour), side);
            if (side < kMinSide)
                continue;

            const float score = use_polygon_score ? polygon_score_acc(contour, pred)
                                                  : box_score_fast(quad, pred);
            if (score < box_thresh)
                continue;

            const cv::RotatedRect expanded = unClip(quad, det_db_unclip_ratio);
            if (expanded.size.height < 1.001 && expanded.size.width < 1.001)
                continue;

            std::vector<std::vector<float>> expanded_quad = get_mini_boxes(expanded, side);
            if (side < kMinSide + 2)
                continue;

            std::vector<std::vector<int>> int_quad;
            for (int p = 0; p < 4; ++p)
            {
                const int x = int(clampf(roundf(expanded_quad[p][0] / src_w * dst_w), 0, dst_w));
                const int y = int(clampf(roundf(expanded_quad[p][1] / src_h * dst_h), 0, dst_h));
                int_quad.push_back(std::vector<int>{x, y});
            }
            result.push_back(int_quad);
        }
        return result;
    }

    /**
     * Copy a float cv::Mat into nested vectors, one inner vector per row.
     * Elements are read with mat.at<float>, so `mat` is expected to be CV_32F.
     */
    std::vector<std::vector<float>> OcrDet::Mat2Vector(cv::Mat mat)
    {
        std::vector<std::vector<float>> table;
        for (int r = 0; r < mat.rows; ++r)
        {
            table.emplace_back();
            std::vector<float> &row = table.back();
            for (int c = 0; c < mat.cols; ++c)
                row.push_back(mat.at<float>(r, c));
        }
        return table;
    }
    
    /**
     * Mean value of `pred` over the area enclosed by `contour`.
     *
     * The contour's bounding box is clamped to `pred`. A mask of that size is
     * filled with the polygon, and cv::mean averages the `pred` ROI under the mask.
     *
     * @return the masked mean, or 0 for an empty contour or an empty `pred`
     */
    float  OcrDet::polygon_score_acc(std::vector<cv::Point> contour,
                                     cv::Mat pred)
    {
        // Guard: min/max_element over an empty contour would dereference end().
        if (contour.empty() || pred.empty())
            return 0.0f;

        const int width = pred.cols;
        const int height = pred.rows;

        int min_x = contour[0].x;
        int max_x = contour[0].x;
        int min_y = contour[0].y;
        int max_y = contour[0].y;
        for (const cv::Point &p : contour)
        {
            min_x = std::min(min_x, p.x);
            max_x = std::max(max_x, p.x);
            min_y = std::min(min_y, p.y);
            max_y = std::max(max_y, p.y);
        }

        const int xmin = clamp(min_x, 0, width - 1);
        const int xmax = clamp(max_x, 0, width - 1);
        const int ymin = clamp(min_y, 0, height - 1);
        const int ymax = clamp(max_y, 0, height - 1);

        cv::Mat mask = cv::Mat::zeros(ymax - ymin + 1, xmax - xmin + 1, CV_8UC1);

        // A std::vector owns the shifted polygon. The previous raw new[]/delete[]
        // leaked whenever an OpenCV call in between threw.
        std::vector<cv::Point> shifted;
        shifted.reserve(contour.size());
        for (const cv::Point &p : contour)
            shifted.emplace_back(p.x - xmin, p.y - ymin);

        const cv::Point *ppt[1] = {shifted.data()};
        int npt[] = {static_cast<int>(shifted.size())};
        cv::fillPoly(mask, ppt, npt, 1, cv::Scalar(1));

        // cv::mean can read the ROI view directly; no copy is needed.
        const cv::Rect roi(xmin, ymin, xmax - xmin + 1, ymax - ymin + 1);
        return static_cast<float>(cv::mean(pred(roi), mask)[0]);
    }
    
    float OcrDet::box_score_fast(std::vector<std::vector<float>> box_array,
                                  cv::Mat pred) 
    {
        auto array = box_array;
        int width = pred.cols;
        int height = pred.rows;

        float box_x[4] = {array[0][0], array[1][0], array[2][0], array[3][0]};
        float box_y[4] = {array[0][1], array[1][1], array[2][1], array[3][1]};

        int xmin = clamp(int(std::floor(*(std::min_element(box_x, box_x + 4)))), 0,
                        width - 1);
        int xmax = clamp(int(std::ceil(*(std::max_element(box_x, box_x + 4)))), 0,
                        width - 1);
        int ymin = clamp(int(std::floor(*(std::min_element(box_y, box_y + 4)))), 0,
                        height - 1);
        int ymax = clamp(int(std::ceil(*(std::max_element(box_y, box_y + 4)))), 0,
                        height - 1);

        cv::Mat mask;
        mask = cv::Mat::zeros(ymax - ymin + 1, xmax - xmin + 1, CV_8UC1);

        cv::Point root_point[4];
        root_point[0] = cv::Point(int(array[0][0]) - xmin, int(array[0][1]) - ymin);
        root_point[1] = cv::Point(int(array[1][0]) - xmin, int(array[1][1]) - ymin);
        root_point[2] = cv::Point(int(array[2][0]) - xmin, int(array[2][1]) - ymin);
        root_point[3] = cv::Point(int(array[3][0]) - xmin, int(array[3][1]) - ymin);
        const cv::Point *ppt[1] = {root_point};
        int npt[] = {4};
        cv::fillPoly(mask, ppt, npt, 1, cv::Scalar(1));

        cv::Mat croppedImg;
        pred(cv::Rect(xmin, ymin, xmax - xmin + 1, ymax - ymin + 1))
            .copyTo(croppedImg);

        auto score = cv::mean(croppedImg, mask)[0];
        return score;
   }
    /**
     * Expand a quadrilateral outward and return the minimum-area rectangle
     * around the result.
     *
     * The offset distance is area * unclip_ratio / perimeter of `box` (see
     * get_contour_area). The offset uses Clipper with round joins on the closed
     * polygon whose vertices are the int-truncated corners of `box`.
     *
     * @param box           four {x, y} corner points
     * @param unclip_ratio  expansion factor
     * @return minAreaRect of all offset vertices, or a 1x1 rect at the origin if
     *         Clipper produced no vertices
     */
    cv::RotatedRect OcrDet::unClip(std::vector<std::vector<float>> box,
                                      const float &unclip_ratio)
    {
        float distance = 1.0;
        get_contour_area(box, unclip_ratio, distance);

        ClipperLib::ClipperOffset offset;
        ClipperLib::Path p;
        p << ClipperLib::IntPoint(int(box[0][0]), int(box[0][1]))
            << ClipperLib::IntPoint(int(box[1][0]), int(box[1][1]))
            << ClipperLib::IntPoint(int(box[2][0]), int(box[2][1]))
            << ClipperLib::IntPoint(int(box[3][0]), int(box[3][1]));
        offset.AddPath(p, ClipperLib::jtRound, ClipperLib::etClosedPolygon);

        ClipperLib::Paths soln;
        offset.Execute(soln, distance);

        // Collect every vertex of every output path. The previous loop bounded
        // the inner index by the size of the LAST path. When Clipper returned
        // several polygons, that read past the end of shorter paths and skipped
        // vertices of longer ones.
        std::vector<cv::Point2f> points;
        for (const ClipperLib::Path &path : soln)
        {
            for (const ClipperLib::IntPoint &pt : path)
                points.emplace_back(pt.X, pt.Y);
        }

        if (points.empty())
            return cv::RotatedRect(cv::Point2f(0, 0), cv::Size2f(1, 1), 0);
        return cv::minAreaRect(points);
    }
    
    /**
     * Compute the unclip offset distance for a 4-point polygon.
     *
     * Area comes from the shoelace formula and the perimeter from the summed edge
     * lengths. The result is distance = |area| * unclip_ratio / perimeter.
     *
     * @param box           four {x, y} vertices in polygon order
     * @param unclip_ratio  expansion factor
     * @param distance      receives the computed offset distance
     */
    void OcrDet::get_contour_area(const std::vector<std::vector<float>> &box,
                                   float unclip_ratio, float &distance)
    {
        const int n = 4;
        float shoelace = 0.0f;
        float perimeter = 0.0f;
        for (int cur = 0; cur < n; ++cur)
        {
            const int nxt = (cur + 1) % n;
            const float x0 = box[cur][0];
            const float y0 = box[cur][1];
            const float x1 = box[nxt][0];
            const float y1 = box[nxt][1];

            shoelace += x0 * y1 - y0 * x1;

            const float dx = x0 - x1;
            const float dy = y0 - y1;
            perimeter += sqrtf(dx * dx + dy * dy);
        }
        const float area = fabs(float(shoelace / 2.0));
        distance = area * unclip_ratio / perimeter;
    }
    
    /**
     * Map detected quads back to source-image coordinates and drop tiny ones.
     *
     * For each box:
     *  - put its corners in clockwise order (order_points_clockwise);
     *  - divide x by ratio_w and y by ratio_h, truncating to int;
     *  - clamp the corners to [0, cols-1] x [0, rows-1] of `srcimg`;
     *  - discard the box if edge 0->1 or edge 0->3 is 4 px or shorter.
     * Boxes with fewer than 4 points are skipped rather than indexed out of bounds.
     *
     * @param boxes    quads in network-output coordinates (taken by value and modified)
     * @param ratio_h  network height / source height
     * @param ratio_w  network width / source width
     * @param srcimg   source image (only its size is used)
     */
    std::vector<std::vector<std::vector<int>>>
    OcrDet::filter_det_res(std::vector<std::vector<std::vector<int>>> boxes,
                                float ratio_h, float ratio_w, cv::Mat srcimg)
    {
        const int oriimg_h = srcimg.rows;
        const int oriimg_w = srcimg.cols;

        // Truncated Euclidean distance between two {x, y} points. Plain
        // multiplication replaces pow(x, 2) and gives the same exact doubles.
        auto edge_len = [](const std::vector<int> &a, const std::vector<int> &b) {
            const double dx = a[0] - b[0];
            const double dy = a[1] - b[1];
            return int(std::sqrt(dx * dx + dy * dy));
        };

        std::vector<std::vector<std::vector<int>>> root_points;
        for (auto &box : boxes)
        {
            if (box.size() < 4)
                continue;

            box = order_points_clockwise(box);
            // Iterate over THIS box's points; the old code used boxes[0].size().
            for (auto &pt : box)
            {
                pt[0] /= ratio_w;
                pt[1] /= ratio_h;
                pt[0] = _min(_max(pt[0], 0), oriimg_w - 1);
                pt[1] = _min(_max(pt[1], 0), oriimg_h - 1);
            }

            const int rect_width  = edge_len(box[0], box[1]);
            const int rect_height = edge_len(box[0], box[3]);
            if (rect_width <= 4 || rect_height <= 4)
                continue;
            root_points.push_back(box);
        }
        return root_points;
    }
    
    /**
     * Reorder four {x, y} int points as
     * [left/smaller-y, right/smaller-y, right/larger-y, left/larger-y].
     *
     * Points are sorted by x, split into the two left-most and two right-most,
     * and each pair is ordered by y (swapped only when strictly greater).
     * Expects at least 4 points.
     */
    std::vector<std::vector<int>> OcrDet::order_points_clockwise(std::vector<std::vector<int>> pts)
    {
        std::vector<std::vector<int>> by_x(pts);
        std::sort(by_x.begin(), by_x.end(), XsortInt);

        std::vector<int> left_a = by_x[0];
        std::vector<int> left_b = by_x[1];
        if (left_a[1] > left_b[1])
            std::swap(left_a, left_b);

        std::vector<int> right_a = by_x[2];
        std::vector<int> right_b = by_x[3];
        if (right_a[1] > right_b[1])
            std::swap(right_a, right_b);

        return std::vector<std::vector<int>>{left_a, right_a, right_b, left_b};
    }

    /**
     * Draw each box as a closed green polyline on a copy of `srcimg` and write
     * the result to ./ocr_debug.png.
     *
     * Every point of every box is drawn; the old fixed cv::Point[4] buffer
     * overflowed for boxes with more than four points. The log line now names
     * the file that is actually written (it used to say ocr_result.png).
     */
    void OcrDet::visualize_boxes(const cv::Mat &srcimg,
        const std::vector<std::vector<std::vector<int>>> &boxes)
    {
        const std::string out_path = "./ocr_debug.png";
        cv::Mat img_vis = srcimg.clone();

        for (const auto &box : boxes)
        {
            std::vector<cv::Point> pts;
            pts.reserve(box.size());
            for (const auto &p : box)
                pts.emplace_back(p[0], p[1]);
            if (pts.empty())
                continue;

            const cv::Point *ppt[1] = {pts.data()};
            int npt[] = {static_cast<int>(pts.size())};
            cv::polylines(img_vis, ppt, npt, 1, true, CV_RGB(0, 255, 0), 2, 8, 0);
        }

        cv::imwrite(out_path, img_vis);
        std::cout << "image saved in " << out_path << std::endl;
    }
    
    /**
     * Crop the axis-aligned region of each box (corner 0 to corner 2) from `srcimg`.
     *
     * Recognition of the crops is not wired up here (the call is commented out in
     * the original). Each rect is intersected with the image bounds and empty
     * rects are skipped. Previously an out-of-image or inverted rect made
     * srcimg(rect) throw a cv::Exception.
     *
     * @return false (after logging) when `boxes` is empty, true otherwise
     */
    bool OcrDet::text_recognition(const cv::Mat &srcimg,
        const std::vector<std::vector<std::vector<int>>> &boxes)
    {
        if(boxes.empty())
        {
            std::cout<<"Not found text roi !\n";
            return false;
        }

        const cv::Rect image_bounds(0, 0, srcimg.cols, srcimg.rows);
        for (const auto &box : boxes)
        {
            if (box.size() < 3)
                continue;

            cv::Rect rect(box[0][0], box[0][1],
                          box[2][0] - box[0][0], box[2][1] - box[0][1]);
            rect &= image_bounds;
            if (rect.empty())
                continue;

            cv::Mat text_mat = srcimg(rect).clone();
            (void)text_mat;   // placeholder until a recognizer is attached here
        }
        return true;
    }
    int OcrDet::postprocess(float* feature, std::vector<std::vector<std::vector<int>>> &boxes)
    {
        int batch_s = 1;
        float conf_thres = 0.6;
        cv::Mat thres_mat = cv::Mat(cv::Size(output_height,output_width), CV_8UC1);
        int feat_size = 20;
        for(int n =0; n< batch_s; n++)
        {
            for (int c =0 ;c<n_channel;c++)
            {
                for(int h = 0;h<output_height;h++)
                {
                    for(int w =0;w<output_width;w++)
                    {
                        thres_mat.at<uchar>(h,w) = feature[n*feature_size*n_channel+c*feature_size+h*output_width+w] > conf_thres ? 1: 0;
                    }
                }
            }
        }
        boxes.clear();
        cv::Mat dilation_map;
        cv::Mat dila_ele = cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2));
        cv::dilate(thres_mat, dilation_map, dila_ele);
        boxes = boxes_from_bitmap(thres_mat, dilation_map, 0.6,1.5, false);
        return 0;        
    }

    /**
     * Run text detection on `img`.
     *
     * Steps: preprocess into the host buffer; run the network, either through the
     * preallocated device buffers (offload_copy == false) or with a host argument
     * (offload_copy == true); post-process into boxes; map the boxes back to
     * image coordinates; save a debug visualization.
     *
     * @param img             source image
     * @param text_roi_boxes  receives the filtered quads in image coordinates
     * @return always true
     */
    bool OcrDet::forward(cv::Mat& img,std::vector<std::vector<std::vector<int>>>& text_roi_boxes)
    {
        std::vector<std::vector<std::vector<int>>> raw_boxes;
        preproc(img, data);   // the returned (integer) ratio is not used

        if (!this->offload_copy)
        {
            hipMemcpy(input_buffer_device, (void*)data,
                      this->input_shape.bytes(), hipMemcpyHostToDevice);
            std::vector<migraphx::argument> outputs = net.eval(dev_argument);
            hipMemcpy(output_buffer_host, (void*)output_buffer_device,
                      output_shape.bytes(), hipMemcpyDeviceToHost);
            postprocess(static_cast<float *>(output_buffer_host), raw_boxes);
            std::cout<<"copy mode ..."<<std::endl;
        }
        else
        {
            std::unordered_map<std::string, migraphx::argument> feed;
            feed[input_name] = migraphx::argument{input_shape, (float *)data};
            std::vector<migraphx::argument> outputs = net.eval(feed);
            const migraphx::argument first_output = outputs[0];
            postprocess((float *)first_output.data(), raw_boxes);
            std::cout<<"offload copy mode ..."<<std::endl;
        }

        const float ratio_w = float(net_input_width) / float(img.cols);
        const float ratio_h = float(net_input_height) / float(img.rows);
        text_roi_boxes = filter_det_res(raw_boxes, ratio_h, ratio_w, img);
        visualize_boxes(img, text_roi_boxes);
        return true;
    }
 
    /**
     * Load, optionally fp16-quantize, and compile the text-recognition model, then
     * read the character dictionary (one entry per line).
     *
     * The model input "x" is pinned to {batch_size, channel, image_height,
     * image_width}, the same geometry preproc() fills. Previously it was
     * hard-coded to {1, 3, 48, 720} no matter which arguments were passed.
     * ppOcrEngine passes exactly those values, so its behavior is unchanged.
     *
     * @param rec_model_path       recognition ONNX model; the process exits with a failure status if missing
     * @param precision_mode       "fp16" enables migraphx::quantize_fp16
     * @param image_width          network input width
     * @param image_height         network input height
     * @param channel              network input channels
     * @param batch_size           network input batch size
     * @param offload_copy         when false, device I/O buffers are allocated here
     * @param character_dict_path  dictionary file; the process exits with a failure status if it cannot be opened
     */
    CTCDecode::CTCDecode(std::string rec_model_path,
        std::string precision_mode,
        int image_width,
        int image_height,
        int channel,
        int batch_size,
        bool offload_copy,
        std::string character_dict_path)
    {
        if(!Exists(rec_model_path))
        {
            LOG_ERROR(stdout, "onnx file not exists!\n");
            exit(EXIT_FAILURE);
        }
        if(image_width <= 0 || image_height <= 0 || channel <= 0 || batch_size <= 0)
        {
            LOG_ERROR(stdout, "invalid recognizer input geometry!\n");
            exit(EXIT_FAILURE);
        }
        this->batch_size = batch_size;
        this->net_input_width=image_width;
        this->net_input_height=image_height;
        this->net_input_channel=channel;
        this->precision_mode = precision_mode;

        migraphx::onnx_options onnx_options;
        onnx_options.map_input_dims["x"] = {static_cast<std::size_t>(batch_size),
                                            static_cast<std::size_t>(channel),
                                            static_cast<std::size_t>(image_height),
                                            static_cast<std::size_t>(image_width)};

        net = migraphx::parse_onnx(rec_model_path,onnx_options);
        LOG_INFO(stdout, "Succeed to load model: %s %s\n", GetFileName(rec_model_path).c_str(),this->precision_mode.c_str());

        if(this->precision_mode.compare("fp16")==0)
        {
            LOG_INFO(stdout, "Set precison mode: %s\n",this->precision_mode.c_str());
            migraphx::quantize_fp16(net);
        }

        std::unordered_map<std::string, migraphx::shape> inputs  = net.get_inputs();
        std::unordered_map<std::string, migraphx::shape> outputs = net.get_outputs();
        this->input_name   = inputs.begin()->first;
        this->input_shape  = inputs.begin()->second;
        this->output_name  = outputs.begin()->first;
        this->output_shape = outputs.begin()->second;

        const int N = this->input_shape.lens()[0];
        const int C = this->input_shape.lens()[1];
        const int H = this->input_shape.lens()[2];
        const int W = this->input_shape.lens()[3];

        // Zero-filled so unused batch slots are deterministic.
        data = (float*)calloc(static_cast<std::size_t>(N) * C * H * W, sizeof(float));
        if(data == nullptr)
        {
            LOG_ERROR(stdout, "failed to allocate the recognition input buffer!\n");
            exit(EXIT_FAILURE);
        }

        this->feature_size = output_shape.lens()[2];
        n_channel = this->output_shape.lens()[1];
        std::cout<<"["<<this->output_shape.lens()[0]<<
        ","<<this->output_shape.lens()[1]<<","<<this->output_shape.lens()[2]<<"]\n";

        this->offload_copy = offload_copy;
        migraphx::compile_options options;
        options.device_id = 0; // default device cuda:0
        options.offload_copy = offload_copy;
        migraphx::target gpuTarget = migraphx::gpu::target{};
        net.compile(gpuTarget, options);

        if( this->offload_copy ==false )
        {
            LOG_INFO(stdout, "Set copy mode ...\n");
            if(hipMalloc(&input_buffer_device, this->input_shape.bytes()) != hipSuccess)
            {
                LOG_ERROR(stdout, "hipMalloc failed for the recognition input buffer!\n");
                exit(EXIT_FAILURE);
            }
            if(hipMalloc(&output_buffer_device, this->output_shape.bytes()) != hipSuccess)
            {
                LOG_ERROR(stdout, "hipMalloc failed for the recognition output buffer!\n");
                exit(EXIT_FAILURE);
            }
            output_buffer_host   =  (void*)malloc(this->output_shape.bytes());

            dev_argument[input_name]  = migraphx::argument{input_shape, input_buffer_device};
            dev_argument[output_name] = migraphx::argument{output_shape, output_buffer_device};
        }

        // An explicit check replaces assert(), which disappears under NDEBUG and
        // would leave k_words empty (decode() would then index out of range).
        std::ifstream infile(character_dict_path, std::ios::in);
        if(!infile.is_open())
        {
            LOG_ERROR(stdout, "Failed to open character dict: %s\n", character_dict_path.c_str());
            exit(EXIT_FAILURE);
        }
        std::string k_work;
        k_words.clear();
        while (std::getline(infile,k_work))
        {
            // Tolerate CRLF dictionaries: a trailing '\r' would end up in the decoded text.
            if(!k_work.empty() && k_work.back() == '\r')
                k_work.pop_back();
            k_words.push_back(k_work);
        }

#ifdef _WIN32
        // Switch the Windows console to UTF-8 so recognized text prints correctly.
        // On other platforms "chcp" does not exist, and calling it only spawned a failing shell.
        system("chcp 65001");
#endif
    }

    /**
     * Free the host input buffer and, when the model was compiled in copy mode
     * (offload_copy == false), the device I/O buffers and host output buffer.
     */
    CTCDecode::~CTCDecode()
    {
        free(data);   // safe for nullptr
        data = nullptr;

        const bool owns_io_buffers = !offload_copy;
        if (!owns_io_buffers)
            return;

        if (input_buffer_device != nullptr)
            hipFree(input_buffer_device);
        if (output_buffer_device != nullptr)
            hipFree(output_buffer_device);
        free(output_buffer_host);
    }

    /**
     * Letterbox `img` into an img_w x img_h canvas and write it as planar floats
     * normalized to (v/255 - 0.5)/0.5 into batch slot 0 of `data`.
     *
     * The image is scaled to height img_h with width iw * min(img_h/ih, img_w/iw)
     * (at least 1 px), then padded with zeros on the right and bottom. Plane 0
     * receives pixel channel 2, plane 1 channel 1, and plane 2 channel 0.
     * NOTE(review): this order is the reverse of OcrDet::preproc; confirm it
     * matches the recognition model.
     * Gray and 4-channel inputs are converted to 3 channels first.
     * All batch_size slots are zeroed; only slot 0 is filled (the old loop
     * rewrote slot 0 batch_size times).
     *
     * @return false (after logging) for an empty image, true otherwise
     */
    bool CTCDecode::preproc(cv::Mat img,float* data,int img_w,int img_h)
    {
        if (img.empty())
        {
            std::cout<<"WARNING image is empty!\n";
            return false;
        }

        // Reading non-3-channel data as cv::Vec3b would go out of bounds.
        cv::Mat src = img;
        if (img.type() == CV_8UC1)
            cv::cvtColor(img, src, cv::COLOR_GRAY2BGR);
        else if (img.type() == CV_8UC4)
            cv::cvtColor(img, src, cv::COLOR_BGRA2BGR);

        const float scale = 1.0/255.;
        const int iw = src.cols;
        const int ih = src.rows;
        const float ratio = std::min(img_h*1.0/ih, img_w*1.0/iw);
        // A very narrow crop could round to width 0, which cv::resize rejects.
        const int nw = std::max(1, static_cast<int>(iw*ratio));
        const int nh = img_h;

        cv::Mat res_mat;
        cv::resize(src, res_mat, cv::Size(nw, nh));
        cv::Mat template_mat;
        cv::copyMakeBorder(res_mat, template_mat, 0, img_h - nh, 0, img_w - nw,
                           cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));

        const int plane = img_w * img_h;
        memset(data, 0, this->batch_size * 3 * plane * sizeof(float));
        for (int i = 0; i < img_h; i++)
        {
            const cv::Vec3b *row = template_mat.ptr<cv::Vec3b>(i);
            for (int j = 0; j < img_w; j++)
            {
                const int idx = i * img_w + j;
                data[idx]             = (row[j][2]*scale-0.5)/0.5;
                data[idx + plane]     = (row[j][1]*scale-0.5)/0.5;
                data[idx + 2 * plane] = (row[j][0]*scale-0.5)/0.5;
            }
        }
        return true;
    }

    /**
     * Greedy CTC decoding of per-step best classes.
     *
     * Steps whose class is 0 (the blank) or equal to the previous step's class
     * are skipped. Every other class k is mapped to k_words[k - 1].
     *
     * @param probs      best-class score per time step
     * @param indexs     best-class index per time step
     * @param mean_prob  reset to 0, then set to the mean score of the emitted
     *                   characters (0 if none were emitted)
     * @return the decoded text
     */
    std::string CTCDecode::decode(std::vector<float>& probs,std::vector<int>& indexs,float& mean_prob)
    {
        const int blank_index = 0;
        std::string text;
        int eff_text_num = 0;
        mean_prob = 0.f;   // do not accumulate onto whatever the caller passed in

        // Never step past either vector, even if n_channel disagrees with them.
        std::size_t steps = std::min(probs.size(), indexs.size());
        if (static_cast<std::size_t>(n_channel) < steps)
            steps = static_cast<std::size_t>(n_channel);

        for (std::size_t i = 0; i < steps; ++i)
        {
            const int idx = indexs[i];
            if (idx == blank_index)
                continue;
            if (i > 0 && indexs[i-1] == idx)
                continue;
            // Guard against classes beyond the dictionary. k_words[idx - 1] was
            // previously read out of bounds for them.
            // NOTE(review): PaddleOCR-style models often reserve the class right
            // after the dictionary for a space character; confirm the desired mapping.
            if (idx < 1 || static_cast<std::size_t>(idx) > k_words.size())
                continue;

            mean_prob += probs[i];
            text += k_words[idx - 1];
            eff_text_num++;
        }

        if (eff_text_num != 0)
            mean_prob /= eff_text_num;
        else
            mean_prob = 0.;

        return text;
    }
    /**
     * Decode the raw recognition output into text.
     *
     * `feature` is read as n_channel consecutive rows of feature_size scores
     * (the original note mentions a 25 x 6625 shape). The first maximum of each
     * row gives that step's class and score, which are then CTC-decoded.
     * The text is printed and returned.
     */
    std::string CTCDecode::postprocess(float* feature)
    {
        std::vector<float> best_scores;
        std::vector<int> best_ids;

        for (int step = 0; step < n_channel; ++step)
        {
            const float *row_begin = feature + step * feature_size;
            const float *row_end = row_begin + feature_size;
            const float *winner = std::max_element(row_begin, row_end);
            best_scores.push_back(*winner);
            best_ids.push_back(static_cast<int>(winner - row_begin));
        }

        float mean_score = 0.;
        const std::string text = decode(best_scores, best_ids, mean_score);
        std::cout<<"ocr res :["<<text<<"]\n";
        return text;
    }

    /**
     * Recognize the text in a single cropped image.
     *
     * The image is preprocessed into the host buffer and the network is run,
     * either through the preallocated device buffers (offload_copy == false) or
     * with a host argument (offload_copy == true). The output is then CTC-decoded.
     *
     * @return the recognized text, or an empty string if `img` is empty
     *         (previously the network ran on stale buffer contents in that case)
     */
    std::string  CTCDecode::forward(cv::Mat& img)
    {
        if (!preproc(img, data, net_input_width, net_input_height))
            return std::string();

        if( this->offload_copy ==false )
        {
            hipMemcpy(input_buffer_device,
                  (void*)data,
                  this->input_shape.bytes(),
                  hipMemcpyHostToDevice);

            std::vector<migraphx::argument> results = net.eval(dev_argument);

            hipMemcpy(output_buffer_host,
                      (void*)output_buffer_device,
                      output_shape.bytes(),
                      hipMemcpyDeviceToHost);

            // Decode from the HOST copy. output_buffer_device is GPU memory, and
            // the previous code dereferenced it on the CPU.
            return postprocess(static_cast<float *>(output_buffer_host));
        }

        std::unordered_map<std::string, migraphx::argument> inputData;
        inputData[input_name] = migraphx::argument{input_shape, (float *)data};
        std::vector<migraphx::argument> results = net.eval(inputData);
        migraphx::argument result = results[0];
        return postprocess((float *)result.data());
    }

    /**
     * Build the detection and recognition stages.
     *
     * The detector receives the thresholds, copy mode and precision directly.
     * The recognizer uses a 720x48, 3-channel, batch-1 input.
     */
    ppOcrEngine::ppOcrEngine(const std::string &det_model_path,
                    const std::string &rec_model_path,
                    const std::string &character_dict_path,
                    float segm_thres,
                    float box_thresh,
                    bool offload_copy,
                    std::string precision_mode )
    {
        text_detector = std::make_shared<OcrDet>(det_model_path, precision_mode, offload_copy,
                                                 segm_thres, box_thresh);

        const int rec_width = 720;
        const int rec_height = 48;
        const int rec_channels = 3;
        const int rec_batch = 1;
        text_recognizer = std::make_shared<CTCDecode>(rec_model_path, precision_mode,
                                                      rec_width, rec_height, rec_channels,
                                                      rec_batch, offload_copy, character_dict_path);
    }

    ppOcrEngine::~ppOcrEngine()
    {
        // Nothing to release by hand: text_detector and text_recognizer were
        // created with std::make_shared and clean up when their owners go away.
    }
    
    /**
     * Detect text regions in `srcimg` and recognize each one.
     *
     * Each detected quad is cropped as the axis-aligned rect from corner 0 to
     * corner 2. Crops narrower or shorter than 3 px are skipped. The elapsed
     * wall time is printed.
     *
     * @return recognized strings in detection order; empty if nothing was detected
     */
    std::vector<std::string> ppOcrEngine::forward(cv::Mat &srcimg)
    {
        std::vector<std::vector<std::vector<int>>> rois;
        std::vector<std::string> texts;

        const auto t_begin = std::chrono::high_resolution_clock::now();
        text_detector->forward(srcimg, rois);
        if (rois.empty())
        {
            std::cout<<"Not found text roi !\n";
            return std::vector<std::string>();
        }
        std::cout<<"text_roi_boxes.size(): "<<rois.size()<<"\n";

        for (const auto &quad : rois)
        {
            const cv::Rect region(quad[0][0], quad[0][1],
                                  quad[2][0] - quad[0][0],
                                  quad[2][1] - quad[0][1]);
            if (region.width < 3 || region.height < 3)
                continue;

            cv::Mat crop = srcimg(region).clone();
            texts.push_back(text_recognizer->forward(crop));
        }

        const auto t_end = std::chrono::high_resolution_clock::now();
        const auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_begin);
        std::cout<<"Time taken by task: "<< elapsed.count() <<" ms\n";
        return texts;
    }

}




 