import os
import sys

# Directory containing this file. Make it and the repository root (two levels
# up) importable so the `tools.*` and `ppocr.*` imports below resolve.
# NOTE: deliberately not named `__dir__`: a module-level `__dir__` is the
# PEP 562 hook that dir(module) calls, so binding it to a string made
# dir(<this module>) raise TypeError.
_CUR_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_CUR_DIR)
sys.path.insert(0, os.path.abspath(os.path.join(_CUR_DIR, '../..')))

# Set before the predictor modules are imported; presumably read by Paddle at
# initialization to select its memory allocator strategy — TODO confirm.
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

import cv2
import copy
import numpy as np
import math
import time
import traceback

import tools.infer.utility as utility
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif

# Module-level logger: passed to the predictor factories by both classifier
# classes and used by main() for progress/error messages.
logger = get_logger()


class TextClassifier(object):
    """Text-line orientation classifier with aspect-ratio-sorted dynamic batching.

    Images are sorted by width/height ratio and classified in full batches of
    ``OCR_REC_MAX_BATNUM`` (default 24) images; any remaining images form one
    final batch that is zero-padded up to a multiple of ``OCR_REC_MIN_BATNUM``
    (default 8). Images whose predicted label contains ``'180'`` with a score
    above ``cls_thresh`` are rotated by 180 degrees.
    """

    def __init__(self, args):
        self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")]
        # Not consulted by __call__ (batch sizes come from environment
        # variables); kept as a public attribute.
        self.cls_batch_num = args.cls_batch_num
        self.cls_thresh = args.cls_thresh
        postprocess_params = {
            'name': 'ClsPostProcess',
            "label_list": args.label_list,
        }
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors, _ = \
            utility.create_predictor(args, 'cls', logger)
        self.use_onnx = args.use_onnx

    def resize_norm_img(self, img):
        """Resize ``img`` to the fixed classifier input and normalize it.

        The image is resized to height ``imgH`` keeping its aspect ratio (width
        capped at ``imgW``). Pixel values are mapped from [0, 255] to [-1, 1].
        The result is right-padded with zeros to shape ``(imgC, imgH, imgW)``.
        """
        imgC, imgH, imgW = self.cls_image_shape
        h = img.shape[0]
        w = img.shape[1]
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        if self.cls_image_shape[0] == 1:
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
        else:
            resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    @staticmethod
    def _target_width(imgH, max_wh_ratio, precision_level):
        """Return the padded batch width for a batch whose widest ratio is ``max_wh_ratio``.

        ``imgH * max_wh_ratio`` is rounded up to the next bucket. The buckets
        are equal fractions of ``max_w = imgH * 48``: level ``'0'`` always
        gives ``max_w``, ``'1'`` uses halves, ``'2'`` uses quarters, and any
        other value (including ``None``) uses sixths. The result never
        exceeds ``max_w``.
        """
        max_w = imgH * 48
        if precision_level == '0':
            return int(max_w)
        n_buckets = {'1': 2, '2': 4}.get(precision_level, 6)
        wanted = int(imgH * max_wh_ratio)
        for k in range(1, n_buckets):
            limit = k * max_w / n_buckets
            if wanted <= limit:
                return int(limit)
        return int(max_w)

    def resize_norm_img_section(self, img, max_wh_ratio):
        """Resize and normalize a 3-D ``img`` for a batch with widest ratio ``max_wh_ratio``.

        The padded width is bucketed according to ``OCR_REC_PRECISION`` (see
        ``_target_width``). Returns a float32 array of shape
        ``(imgC, imgH, batch_width)`` with values mapped to [-1, 1].
        """
        imgC, imgH, _ = self.cls_image_shape
        assert imgC == img.shape[2]

        imgW = self._target_width(imgH, max_wh_ratio,
                                  os.environ.get("OCR_REC_PRECISION"))
        h, w = img.shape[:2]
        ratio = w / float(h)
        resized_w = min(imgW, int(math.ceil(imgH * ratio)))
        resized_image = cv2.resize(img, (resized_w, imgH)).astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    @staticmethod
    def _batch_ranges(img_num, max_batnum, min_batnum):
        """Split ``img_num`` sorted images into ``(begin, end, padded_size)`` batches.

        Full batches of ``max_batnum`` images come first. A non-empty
        remainder forms one last batch padded up to the next multiple of
        ``min_batnum``. An exact multiple of ``max_batnum`` produces no
        (empty) trailing batch.
        """
        n_full = img_num // max_batnum
        ranges = [(beg, beg + max_batnum, max_batnum)
                  for beg in range(0, n_full * max_batnum, max_batnum)]
        left = img_num - n_full * max_batnum
        if left > 0:
            padded = int(math.ceil(left / min_batnum)) * min_batnum
            ranges.append((n_full * max_batnum, img_num, padded))
        return ranges

    def _predict(self, norm_img_batch):
        """Run the predictor on one (N, C, H, W) float32 batch and return its first output."""
        if self.use_onnx:
            input_dict = {self.input_tensor.name: norm_img_batch}
            return self.predictor.run(self.output_tensors, input_dict)[0]
        self.input_tensor.copy_from_cpu(norm_img_batch)
        self.predictor.run()
        prob_out = self.output_tensors[0].copy_to_cpu()
        self.predictor.try_shrink_memory()
        return prob_out

    def _classify_batch(self, img_list, indices, cls_res, beg, end, padded_size):
        """Classify the images ``indices[beg:end]`` as one batch.

        The batch is zero-padded to ``padded_size`` rows, and predictions for
        the padding rows are discarded. ``cls_res`` and ``img_list`` are
        updated in place at each image's original position.
        """
        batch_idx = indices[beg:end]
        # The shared batch width is derived from the widest image's aspect ratio.
        max_wh_ratio = max(
            (img_list[i].shape[1] * 1.0 / img_list[i].shape[0] for i in batch_idx),
            default=0)
        norm_img_batch = np.concatenate(
            [self.resize_norm_img_section(img_list[i], max_wh_ratio)[np.newaxis, :]
             for i in batch_idx],
            axis=0)
        n_pad = padded_size - norm_img_batch.shape[0]
        if n_pad > 0:
            pad = np.zeros((n_pad,) + norm_img_batch.shape[1:], dtype=np.float32)
            norm_img_batch = np.concatenate([norm_img_batch, pad])
        norm_img_batch = np.ascontiguousarray(norm_img_batch)

        cls_result = self.postprocess_op(self._predict(norm_img_batch))
        # zip stops after the real images, so padding rows never touch results.
        for img_idx, (label, score) in zip(batch_idx, cls_result):
            cls_res[img_idx] = [label, score]
            if '180' in label and score > self.cls_thresh:
                img_list[img_idx] = cv2.rotate(img_list[img_idx], cv2.ROTATE_180)

    def __call__(self, img_list):
        """Classify ``img_list`` and rotate back images predicted as upside down.

        Args:
            img_list: list of images with ``shape[0]`` = height and
                ``shape[1]`` = width. The list is deep-copied, so the caller's
                images are not modified.

        Returns:
            ``(img_list, cls_res, elapsed)``: the (possibly rotated) copies,
            one ``[label, score]`` per input image in input order, and the
            seconds spent batching and predicting. An empty input yields
            ``([], [], 0)``.
        """
        img_list = copy.deepcopy(img_list)
        img_num = len(img_list)
        cls_res = [['', 0.0] for _ in range(img_num)]
        if img_num <= 0:
            return img_list, cls_res, 0

        # Sorting by aspect ratio groups similarly-sized lines, which reduces padding.
        width_list = [img.shape[1] / float(img.shape[0]) for img in img_list]
        indices = np.argsort(np.array(width_list))

        max_batnum = int(os.environ.get("OCR_REC_MAX_BATNUM", 24))
        min_batnum = int(os.environ.get("OCR_REC_MIN_BATNUM", 8))
        assert max_batnum / min_batnum == int(max_batnum / min_batnum), "max_batnum must be multiple of min_batnum."

        st = time.time()
        for beg, end, padded_size in self._batch_ranges(img_num, max_batnum, min_batnum):
            self._classify_batch(img_list, indices, cls_res, beg, end, padded_size)
        return img_list, cls_res, time.time() - st


class TextClassifier_warmup(object):
    """Orientation classifier built from explicit model directories.

    The predictor comes from ``utility.create_predictor_warmup``, with a
    fixed 3x48x192 input shape, labels ``['0', '180']`` and threshold 0.9.
    Images are sorted by width/height ratio and classified in full batches of
    ``OCR_REC_MAX_BATNUM`` (default 24) images; any remaining images form one
    final batch zero-padded up to a multiple of ``OCR_REC_MIN_BATNUM``
    (default 8). Images predicted as ``'180'`` with a score above
    ``cls_thresh`` are rotated by 180 degrees.
    """

    def __init__(self, det_model_dir, cls_model_dir, rec_model_dir, use_onnx):
        cls_image_shape = "3, 48, 192"
        self.cls_image_shape = [int(v) for v in cls_image_shape.split(",")]
        # Not consulted by __call__ (batch sizes come from environment
        # variables); kept as a public attribute.
        self.cls_batch_num = 6
        self.cls_thresh = 0.9
        postprocess_params = {
            'name': 'ClsPostProcess',
            "label_list": ['0', '180'],
        }
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors, _ = \
            utility.create_predictor_warmup(det_model_dir, cls_model_dir, rec_model_dir, use_onnx, 'cls', logger)
        self.use_onnx = use_onnx

    def resize_norm_img(self, img):
        """Resize ``img`` to the fixed classifier input and normalize it.

        The image is resized to height ``imgH`` keeping its aspect ratio (width
        capped at ``imgW``). Pixel values are mapped from [0, 255] to [-1, 1].
        The result is right-padded with zeros to shape ``(imgC, imgH, imgW)``.
        """
        imgC, imgH, imgW = self.cls_image_shape
        h = img.shape[0]
        w = img.shape[1]
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        if self.cls_image_shape[0] == 1:
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
        else:
            resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    @staticmethod
    def _target_width(imgH, max_wh_ratio, precision_level):
        """Return the padded batch width for a batch whose widest ratio is ``max_wh_ratio``.

        ``imgH * max_wh_ratio`` is rounded up to the next bucket. The buckets
        are equal fractions of ``max_w = imgH * 48``: level ``'0'`` always
        gives ``max_w``, ``'1'`` uses halves, ``'2'`` uses quarters, and any
        other value (including ``None``) uses sixths. The result never
        exceeds ``max_w``.
        """
        max_w = imgH * 48
        if precision_level == '0':
            return int(max_w)
        n_buckets = {'1': 2, '2': 4}.get(precision_level, 6)
        wanted = int(imgH * max_wh_ratio)
        for k in range(1, n_buckets):
            limit = k * max_w / n_buckets
            if wanted <= limit:
                return int(limit)
        return int(max_w)

    def resize_norm_img_section(self, img, max_wh_ratio):
        """Resize and normalize a 3-D ``img`` for a batch with widest ratio ``max_wh_ratio``.

        The padded width is bucketed according to ``OCR_REC_PRECISION`` (see
        ``_target_width``). Returns a float32 array of shape
        ``(imgC, imgH, batch_width)`` with values mapped to [-1, 1].
        """
        imgC, imgH, _ = self.cls_image_shape
        assert imgC == img.shape[2]

        imgW = self._target_width(imgH, max_wh_ratio,
                                  os.environ.get("OCR_REC_PRECISION"))
        h, w = img.shape[:2]
        ratio = w / float(h)
        resized_w = min(imgW, int(math.ceil(imgH * ratio)))
        resized_image = cv2.resize(img, (resized_w, imgH)).astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    @staticmethod
    def _batch_ranges(img_num, max_batnum, min_batnum):
        """Split ``img_num`` sorted images into ``(begin, end, padded_size)`` batches.

        Full batches of ``max_batnum`` images come first. A non-empty
        remainder forms one last batch padded up to the next multiple of
        ``min_batnum``. An exact multiple of ``max_batnum`` produces no
        (empty) trailing batch.
        """
        n_full = img_num // max_batnum
        ranges = [(beg, beg + max_batnum, max_batnum)
                  for beg in range(0, n_full * max_batnum, max_batnum)]
        left = img_num - n_full * max_batnum
        if left > 0:
            padded = int(math.ceil(left / min_batnum)) * min_batnum
            ranges.append((n_full * max_batnum, img_num, padded))
        return ranges

    def _predict(self, norm_img_batch):
        """Run the predictor on one (N, C, H, W) float32 batch and return its first output."""
        if self.use_onnx:
            input_dict = {self.input_tensor.name: norm_img_batch}
            return self.predictor.run(self.output_tensors, input_dict)[0]
        self.input_tensor.copy_from_cpu(norm_img_batch)
        self.predictor.run()
        prob_out = self.output_tensors[0].copy_to_cpu()
        self.predictor.try_shrink_memory()
        return prob_out

    def _classify_batch(self, img_list, indices, cls_res, beg, end, padded_size):
        """Classify the images ``indices[beg:end]`` as one batch.

        The batch is zero-padded to ``padded_size`` rows, and predictions for
        the padding rows are discarded. ``cls_res`` and ``img_list`` are
        updated in place at each image's original position.
        """
        batch_idx = indices[beg:end]
        # The shared batch width is derived from the widest image's aspect ratio.
        max_wh_ratio = max(
            (img_list[i].shape[1] * 1.0 / img_list[i].shape[0] for i in batch_idx),
            default=0)
        norm_img_batch = np.concatenate(
            [self.resize_norm_img_section(img_list[i], max_wh_ratio)[np.newaxis, :]
             for i in batch_idx],
            axis=0)
        n_pad = padded_size - norm_img_batch.shape[0]
        if n_pad > 0:
            pad = np.zeros((n_pad,) + norm_img_batch.shape[1:], dtype=np.float32)
            norm_img_batch = np.concatenate([norm_img_batch, pad])
        norm_img_batch = np.ascontiguousarray(norm_img_batch)

        cls_result = self.postprocess_op(self._predict(norm_img_batch))
        # zip stops after the real images, so padding rows never touch results.
        for img_idx, (label, score) in zip(batch_idx, cls_result):
            cls_res[img_idx] = [label, score]
            if '180' in label and score > self.cls_thresh:
                img_list[img_idx] = cv2.rotate(img_list[img_idx], cv2.ROTATE_180)

    def __call__(self, img_list):
        """Classify ``img_list`` and rotate back images predicted as upside down.

        Args:
            img_list: list of images with ``shape[0]`` = height and
                ``shape[1]`` = width. The list is deep-copied, so the caller's
                images are not modified.

        Returns:
            ``(img_list, cls_res, elapsed)``: the (possibly rotated) copies,
            one ``[label, score]`` per input image in input order, and the
            seconds spent batching and predicting. An empty input yields
            ``([], [], 0)``.
        """
        img_list = copy.deepcopy(img_list)
        img_num = len(img_list)
        cls_res = [['', 0.0] for _ in range(img_num)]
        if img_num <= 0:
            return img_list, cls_res, 0

        # Sorting by aspect ratio groups similarly-sized lines, which reduces padding.
        width_list = [img.shape[1] / float(img.shape[0]) for img in img_list]
        indices = np.argsort(np.array(width_list))

        max_batnum = int(os.environ.get("OCR_REC_MAX_BATNUM", 24))
        min_batnum = int(os.environ.get("OCR_REC_MIN_BATNUM", 8))
        assert max_batnum / min_batnum == int(max_batnum / min_batnum), "max_batnum must be multiple of min_batnum."

        st = time.time()
        for beg, end, padded_size in self._batch_ranges(img_num, max_batnum, min_batnum):
            self._classify_batch(img_list, indices, cls_res, beg, end, padded_size)
        return img_list, cls_res, time.time() - st
    
    
def main(args):
    """Classify every image found under ``args.image_dir`` and log the predictions.

    Files that cannot be loaded are logged and skipped. If classification
    raises, the traceback is logged as an error and the process exits with
    status 1. Previously it called ``exit()``, which reports success (status 0)
    and depends on the ``site`` module.
    """
    image_file_list = get_image_file_list(args.image_dir)
    text_classifier = TextClassifier(args)
    valid_image_file_list = []
    img_list = []
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        # Fall back to OpenCV when check_and_read_gif did not load the file.
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)
    try:
        img_list, cls_res, _ = text_classifier(img_list)
    except Exception as E:
        logger.error(traceback.format_exc())
        logger.error(E)
        sys.exit(1)
    # zip pairs each successfully loaded file with its own prediction.
    for image_file, res in zip(valid_image_file_list, cls_res):
        logger.info("Predicts of {}:{}".format(image_file, res))


if __name__ == "__main__":
    # Script entry point: parse command-line arguments, then run main().
    cli_args = utility.parse_args()
    main(cli_args)
