import cv2
import copy
import numpy as np
import json
import time
from utils import get_image_file_list
from utils import get_rotate_crop_image
from dbnet.dbnet_infer import DBNET
from crnn.CRNN import CRNNHandle
from angnet.angle import AngleNetHandle
from PIL import Image
import os
import argparse


def str2bool(v):
    """Interpret a command-line string as a boolean.

    The comparison is case-insensitive. Only "true", "t" and "1" count as
    True; every other string (including "yes" or the empty string) is False.
    """
    truthy_spellings = ("true", "t", "1")
    normalized = v.lower()
    return normalized in truthy_spellings

# Command-line interface of the script. Arguments are parsed at import time
# and exposed through the module-level ``args`` namespace.
parser = argparse.ArgumentParser()

# Boolean switches, parsed with str2bool and disabled by default:
#   --warmup         run the models on the warm-up images before inference
#   --use_angle_cls  run the angle classifier on every cropped text region
for _bool_flag in ("--warmup", "--use_angle_cls"):
    parser.add_argument(_bool_flag, type=str2bool, default=False)

# Path-like string options (default None when omitted):
#   --img_dir        location handed to get_image_file_list for inference
#   --det_model_dir  model path given to the DBNET detector
#   --rec_model_dir  model path given to the CRNN recognizer
#   --cls_model_dir  model path given to the angle classifier
for _path_option in (
    "--img_dir",
    "--det_model_dir",
    "--rec_model_dir",
    "--cls_model_dir",
):
    parser.add_argument(_path_option, type=str)

args = parser.parse_args()

def sorted_boxes(dt_boxes):
    """
    Sort text boxes in reading order: top to bottom, then left to right.

    Boxes are first ordered by the (y, x) coordinates of their first point.
    Afterwards, adjacent boxes whose first-point y values differ by less than
    10 are treated as lying on the same text line and are reordered so that
    the one with the smaller first-point x comes first.

    args:
        dt_boxes(array): detected text boxes; a sequence of N boxes, each
            indexable as box[0][0] (x) and box[0][1] (y) for its first point
            (typically an array of shape [N, 4, 2])
    return:
        list of the N boxes in sorted order
    """
    _boxes = sorted(dt_boxes, key=lambda box: (box[0][1], box[0][0]))
    num_boxes = len(_boxes)

    for i in range(num_boxes - 1):
        # Bubble box i + 1 towards the front while it sits on the same line
        # as its predecessor but further to the left. The range must end at
        # -1 (exclusive) so that the pair (0, 1) is examined as well.
        for j in range(i, -1, -1):
            on_same_line = abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10
            if on_same_line and _boxes[j + 1][0][0] < _boxes[j][0][0]:
                _boxes[j], _boxes[j + 1] = _boxes[j + 1], _boxes[j]
            else:
                break
    return _boxes


def _run_warmup(dbnet, crnn, anglenet):
    """Run every model once on the bundled warm-up images and print the time.

    The recognizer is exercised with batch sizes min_batnum, 2 * min_batnum,
    ... up to max_batnum. Both values default to 8 and 24 and can be
    overridden with the OCR_REC_MIN_BATNUM / OCR_REC_MAX_BATNUM environment
    variables.

    raises:
        ValueError: if max_batnum is not a multiple of min_batnum, or an
            environment override is not an integer
    """
    warmup_file_list = get_image_file_list("./warmup_images_5/")
    warmup_file_rec_list = get_image_file_list("./warmup_images_rec/")
    warmup_file_cls = "./warmup_images_rec/ArT_2708_1.jpg"

    # The batch configuration does not depend on the image, so read and
    # validate it once instead of once per warm-up image.
    max_batnum = int(os.environ.get("OCR_REC_MAX_BATNUM", 24))
    min_batnum = int(os.environ.get("OCR_REC_MIN_BATNUM", 8))
    # Explicit raise rather than assert: asserts are stripped under "-O".
    if max_batnum % min_batnum != 0:
        raise ValueError("max_batnum must be multiple of min_batnum.")
    num_batch_steps = max_batnum // min_batnum

    startwarm = time.time()
    for warmup_file in warmup_file_list:
        print(warmup_file)
        dbnet.process(cv2.imread(warmup_file))
    for warmup_file_rec in warmup_file_rec_list:
        print(warmup_file_rec)
        img_warm_rec = cv2.imread(warmup_file_rec)
        for bn in range(num_batch_steps):
            # Feed the same image repeated to form a batch of the target size.
            crnn([img_warm_rec] * (min_batnum * (bn + 1)))
    anglenet(cv2.imread(warmup_file_cls))
    elapsewarm = time.time() - startwarm
    print("warmup time:", elapsewarm)


def main():
    """Run detection, optional angle correction and recognition on images.

    Builds the three models from the paths given on the command line,
    optionally warms them up (--warmup), then processes every file returned
    by get_image_file_list(args.img_dir). For each image the per-stage timings
    and the recognition result of every cropped text region are printed.
    Images that cannot be read, or for which the detector returns no boxes
    (None), are skipped so that the remaining images are still processed.
    """
    dbnet = DBNET(MODEL_PATH=args.det_model_dir)
    crnn = CRNNHandle(model_path=args.rec_model_dir)
    anglenet = AngleNetHandle(model_path=args.cls_model_dir)
    if args.warmup:
        _run_warmup(dbnet, crnn, anglenet)

    for image_file in get_image_file_list(args.img_dir):
        print(image_file)
        img = cv2.imread(image_file)
        if img is None:
            # cv2.imread signals an unreadable/unsupported file with None.
            print("failed to read image, skipping:", image_file)
            continue
        ori_im = img.copy()
        st = time.time()
        dt_boxes, _ = dbnet.process(img)
        # Must be checked before len(); skip this image only, not the rest.
        if dt_boxes is None:
            print("no text boxes detected, skipping:", image_file)
            continue
        print(len(dt_boxes))
        db_time = time.time()
        print("db time:", db_time - st)

        dt_boxes = sorted_boxes(dt_boxes)

        img_crop_list = []
        st_ang = time.time()
        for box in dt_boxes:
            # astype returns a new array, so the stored box is left untouched.
            img_crop = get_rotate_crop_image(ori_im, box.astype(np.float32))

            if args.use_angle_cls:
                # A falsy classifier result means the crop is turned by 180
                # degrees before recognition.
                rec_angle = anglenet(img_crop)
                if not rec_angle:
                    img_crop = cv2.rotate(img_crop, cv2.ROTATE_180)

            img_crop_list.append(img_crop)
        ed_ang = time.time()
        print("ang time:", ed_ang - st_ang)

        st_rec = time.time()
        rec_res, _ = crnn(img_crop_list)
        ed = time.time()
        print("rec time:", ed - st_rec)
        print("infer time:", ed - st)
        for i in range(len(img_crop_list)):
            print(rec_res[i])

# Run the OCR pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
