from PIL import  Image
import numpy as np
import cv2
import os
import sys

# Directory containing this script. Deliberately not named ``__dir__``: at module
# level that name is the PEP 562 hook that ``dir(module)`` calls, so binding a
# string to it would make ``dir()`` on this module raise TypeError.
_cur_dir = os.path.dirname(os.path.abspath(__file__))
# Make modules next to this file importable (presumably ``keys`` / ``util``
# imported below), and give the directory two levels up top priority.
sys.path.append(_cur_dir)
sys.path.insert(0, os.path.abspath(os.path.join(_cur_dir, '../..')))

# NOTE(review): this looks like a PaddlePaddle allocator flag; nothing visible in
# this file uses Paddle — confirm it is still needed.
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

from keys import alphabetChinese as alphabet

import onnxruntime as rt
from util import strLabelConverter, resizeNormalize
import os
import time
import math

# Module-level label converter used by CRNNHandle's predict / predict_rbg /
# __call__ to turn arg-max index sequences into text. ``''.join`` works whether
# ``alphabet`` (keys.alphabetChinese) is a str or a sequence of characters.
converter = strLabelConverter(''.join(alphabet))


def _check_image_file(path):
    img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'}
    return any([path.lower().endswith(e) for e in img_end])

def get_image_file_list(img_file):
    """Collect image paths from a single file or from one directory level.

    Args:
        img_file: path to an image file, or to a directory whose direct
            children are scanned (no recursion).

    Returns:
        Sorted list of paths accepted by ``_check_image_file``.

    Raises:
        Exception: if ``img_file`` is None, does not exist, or yields no images.
    """
    if img_file is None or not os.path.exists(img_file):
        raise Exception("not found any img file in {}".format(img_file))

    if os.path.isfile(img_file) and _check_image_file(img_file):
        found = [img_file]
    elif os.path.isdir(img_file):
        children = (os.path.join(img_file, name) for name in os.listdir(img_file))
        found = [p for p in children if os.path.isfile(p) and _check_image_file(p)]
    else:
        found = []

    if not found:
        raise Exception("not found any img file in {}".format(img_file))
    return sorted(found)


def softmax(x):
    """Numerically stable softmax of a NumPy array over its last axis.

    The per-row maximum is subtracted before exponentiating so large inputs do
    not overflow; the output has the same shape as *x*.
    """
    shifted = x - x.max(axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=-1, keepdims=True)


class CRNNHandle:
    """CRNN text-line recognizer backed by an ONNX Runtime session.

    The session is fed a float32 tensor under the input name ``"input"`` and the
    ``"out"`` output is read back. Batched outputs are indexed as
    ``preds[:, slot, :]`` — presumably (time, batch, classes); confirm against
    the exported model. Decoding is greedy (arg-max per time step) through the
    module-level ``converter``.
    """

    # Number of equal-width buckets the padded batch width is rounded up to,
    # keyed by the OCR_REC_PRECISION environment variable. Any other value
    # (including unset) uses _DEFAULT_PRECISION_BUCKETS.
    _PRECISION_BUCKETS = {'0': 1, '1': 2, '2': 4}
    _DEFAULT_PRECISION_BUCKETS = 6

    def __init__(self, model_path, providers=None):
        """Create the inference session.

        Args:
            model_path: path to the ONNX model file.
            providers: optional ONNX Runtime execution-provider list. When None,
                uses ROCm on device "3" with CPU fallback (the original
                hard-coded configuration).
        """
        if providers is None:
            providers = [('ROCMExecutionProvider', {'device_id': '3'}), 'CPUExecutionProvider']
        self.sess = rt.InferenceSession(model_path, providers=providers)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _scaled_width(size, target_h=32):
        """Width an image of ``size == (width, height)`` gets when scaled to ``target_h``.

        Aspect ratio is kept; the result is clamped to at least 1 so very tall,
        narrow inputs still produce a valid resize target.
        """
        width, height = size
        scale = height * 1.0 / target_h
        return max(1, int(width / scale))

    @staticmethod
    def _bucket_width(needed_w, max_w, buckets):
        """Round ``needed_w`` up to the next multiple of ``max_w / buckets``.

        The result is at least one bucket wide and at most ``max_w``, so batches
        share a small set of input widths.
        """
        # Integer ceil-division: smallest k with needed_w <= k * max_w / buckets.
        k = max(1, -(-needed_w * buckets // max_w))
        k = min(k, buckets)
        return int(k * max_w / buckets)

    @staticmethod
    def _batch_limits():
        """Return ``(max_batnum, min_batnum)`` from the environment (defaults 24 and 8).

        Raises:
            ValueError: if either value is not positive, or max_batnum is not a
                multiple of min_batnum (the tail batch is padded to a multiple of
                min_batnum and must not exceed max_batnum).
        """
        max_batnum = int(os.environ.get("OCR_REC_MAX_BATNUM", 24))
        min_batnum = int(os.environ.get("OCR_REC_MIN_BATNUM", 8))
        if min_batnum <= 0 or max_batnum <= 0 or max_batnum % min_batnum != 0:
            raise ValueError(
                "max_batnum and min_batnum must be positive and "
                "max_batnum must be multiple of min_batnum.")
        return max_batnum, min_batnum

    def _run(self, batch):
        """Run the session on an NCHW batch and return the ``"out"`` array."""
        return self.sess.run(["out"], {"input": batch.astype(np.float32)})[0]

    @staticmethod
    def _decode(preds):
        """Greedy-decode one sample's network output into text.

        ``preds``' first axis is treated as time steps; the remaining axes are
        flattened per step and arg-maxed, then passed to ``converter.decode``.
        """
        length = preds.shape[0]
        best = np.argmax(preds.reshape(length, -1), axis=1).reshape(-1)
        return converter.decode(best, length, raw=False)

    # ------------------------------------------------------ single-image paths

    def predict(self, image):
        """Recognize the text in a single image via ``util.resizeNormalize``.

        Args:
            image: PIL-style image exposing ``.size == (width, height)``. It is
                scaled to height 32 by ``resizeNormalize``, whose output is
                assumed to be HWC (it is transposed to CHW here) — TODO confirm.

        Returns:
            The decoded string.
        """
        transformer = resizeNormalize((self._scaled_width(image.size), 32))
        chw = transformer(image).transpose(2, 0, 1)
        return self._decode(self._run(np.expand_dims(chw, axis=0)))

    def predict_rbg(self, im):
        """Recognize the text in a single PIL image, normalizing it here.

        The image is bilinearly resized to height 32 (aspect ratio kept) and
        mapped from [0, 255] to [-1, 1] via ``(x - 127.5) / 127.5``. Non-RGB
        images (grayscale, palette, RGBA, ...) are converted to RGB first so the
        HWC->CHW transpose gets the 3-channel layout also used by
        ``resize_norm_img_section`` (imgC = 3).

        Args:
            im: a PIL image.

        Returns:
            The decoded string.
        """
        if im.mode != 'RGB':
            im = im.convert('RGB')
        img = im.resize((self._scaled_width(im.size), 32), Image.BILINEAR)
        img = np.array(img, dtype=np.float32)
        img -= 127.5
        img /= 127.5
        chw = img.transpose(2, 0, 1)
        return self._decode(self._run(np.expand_dims(chw, axis=0)))

    # ------------------------------------------------------------ batched path

    def resize_norm_img_section(self, img, max_wh_ratio):
        """Resize, normalize and right-pad one text-line image for batched inference.

        The padded width is ``int(32 * max_wh_ratio)`` rounded up to the next of
        N equal buckets of ``max_w = 32 * 48``. N comes from OCR_REC_PRECISION:
        '0' -> 1, '1' -> 2, '2' -> 4, otherwise 6. The width is capped at
        ``max_w``.

        Args:
            img: image array with ``img.shape[:2] == (h, w)`` and a trailing
                channel axis (transposed HWC->CHW below); presumably a 3-channel
                ``cv2.imread`` result — confirm with caller.
            max_wh_ratio: largest width/height ratio in the batch.

        Returns:
            float32 array of shape (3, 32, padded_width). The image area is
            scaled to [-1, 1] for 0..255 input, and the right-hand padding is 0.
        """
        imgC, imgH = 3, 32
        max_w = imgH * 48
        buckets = self._PRECISION_BUCKETS.get(
            os.environ.get("OCR_REC_PRECISION"), self._DEFAULT_PRECISION_BUCKETS)
        imgW = self._bucket_width(int(imgH * max_wh_ratio), max_w, buckets)

        h, w = img.shape[:2]
        ratio = w / float(h)
        # Keep the aspect ratio but never exceed the padded batch width.
        resized_w = min(imgW, int(math.ceil(imgH * ratio)))
        resized_image = cv2.resize(img, (resized_w, imgH)).astype('float32')
        # HWC -> CHW, then [0, 255] -> [-1, 1].
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def _recognize_batch(self, img_list, batch_indices, batch_size):
        """Recognize ``img_list[i]`` for each i in ``batch_indices`` in one session run.

        All images are padded to the width implied by the batch's widest aspect
        ratio. The batch is then zero-padded to ``batch_size`` samples; outputs
        for the padding slots are discarded.

        Returns:
            Decoded strings, in the order of ``batch_indices``.
        """
        batch_imgs = [img_list[i] for i in batch_indices]
        max_wh_ratio = max(img.shape[1] * 1.0 / img.shape[0] for img in batch_imgs)
        norm_batch = np.stack(
            [self.resize_norm_img_section(img, max_wh_ratio) for img in batch_imgs])
        n_pad = batch_size - norm_batch.shape[0]
        if n_pad > 0:
            filler = np.zeros((n_pad,) + norm_batch.shape[1:], dtype=np.float32)
            norm_batch = np.concatenate([norm_batch, filler])
        preds = self._run(norm_batch)
        return [self._decode(preds[:, slot:slot + 1, :]) for slot in range(len(batch_imgs))]

    def __call__(self, img_list):
        """Recognize a list of text-line images in batches.

        Images are sorted by aspect ratio so each batch pads to a similar width.
        If there are more than ``max_batnum`` images, as many full batches of
        ``max_batnum`` as fit are run first. Any remainder runs as one batch
        padded up to a multiple of ``min_batnum``. Both limits come from
        OCR_REC_MAX_BATNUM / OCR_REC_MIN_BATNUM (defaults 24 / 8).

        Args:
            img_list: list of image arrays (see ``resize_norm_img_section``).

        Returns:
            ``(rec_res, elapsed)``: ``rec_res[i]`` is the text for ``img_list[i]``
            (exactly ``len(img_list)`` entries), and ``elapsed`` is wall-clock
            seconds spent in batching and inference (0 for an empty list).

        Raises:
            ValueError: if the batch-size environment variables are invalid.
        """
        img_num = len(img_list)
        rec_res = [''] * img_num
        if img_num <= 0:
            return rec_res, 0
        max_batnum, min_batnum = self._batch_limits()

        # Sorting by aspect ratio keeps per-batch padding small.
        indices = np.argsort(np.array([img.shape[1] / float(img.shape[0]) for img in img_list]))

        st = time.time()
        full_end = 0
        if img_num > max_batnum:
            full_end = (img_num // max_batnum) * max_batnum
            for beg in range(0, full_end, max_batnum):
                batch_indices = indices[beg:beg + max_batnum]
                texts = self._recognize_batch(img_list, batch_indices, max_batnum)
                for img_idx, text in zip(batch_indices, texts):
                    rec_res[img_idx] = text

        remaining = img_num - full_end
        # Nothing is left when img_num is an exact multiple of max_batnum.
        if remaining > 0:
            padded = math.ceil(remaining / min_batnum) * min_batnum
            batch_indices = indices[full_end:]
            texts = self._recognize_batch(img_list, batch_indices, padded)
            for img_idx, text in zip(batch_indices, texts):
                rec_res[img_idx] = text

        return rec_res, time.time() - st



if __name__ == "__main__":
    # Smoke test / warm-up: recognize every image in ./warmup_images_rec once via
    # the single-image PIL path and once via the batched cv2 path.
    image_file_list = get_image_file_list("warmup_images_rec")
    crnn_handle = CRNNHandle(model_path="./models/crnn_lite_lstm.onnx")

    img_list = []
    for image_file in image_file_list:
        img = cv2.imread(image_file)
        # cv2.imread returns None instead of raising for files it cannot decode;
        # such an entry would crash the batched run on ``img.shape``.
        if img is not None:
            img_list.append(img)
        else:
            print("cv2 could not read {}, skipped in batch run".format(image_file),
                  file=sys.stderr)

        # ``with`` closes the underlying file handle once prediction is done.
        with Image.open(image_file) as im:
            print(crnn_handle.predict_rbg(im))

    rec_res, _ = crnn_handle(img_list)
    # Slice guards against implementations that append padding entries.
    for text in rec_res[:len(img_list)]:
        print(text)
    
