import cv2
import copy
import numpy as np
import json
import time
from utils import get_image_file_list
from utils import get_rotate_crop_image
from dbnet.dbnet_infer import DBNET
from crnn.CRNN import CRNNHandle
from angnet.angle import AngleNetHandle
from PIL import Image
import os
import argparse


def str2bool(v):
    """Interpret a command-line string as a boolean.

    Only "true", "t" and "1" (case-insensitive) are True; anything else is False.
    """
    return v.lower() in ("true", "t", "1")


parser = argparse.ArgumentParser()
parser.add_argument("--warmup", type=str2bool, default=False)
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
parser.add_argument("--img_dir", type=str)
parser.add_argument("--det_model_dir", type=str)
parser.add_argument("--rec_model_dir", type=str)
parser.add_argument("--cls_model_dir", type=str)
args = parser.parse_args()


def sorted_boxes(dt_boxes, y_tolerance=10):
    """Sort text boxes top-to-bottom, then left-to-right within a line.

    Boxes are first ordered by the (y, x) of their first corner point. An
    insertion-style pass then moves each box leftwards past its predecessors
    while both lie on the same line (first-corner y differs by less than
    ``y_tolerance``) and the box starts further to the left.

    Args:
        dt_boxes: array of detected boxes. ``dt_boxes.shape[0]`` is the number
            of boxes and ``box[0]`` is a box's first (x, y) corner
            (presumably shape (N, 4, 2) -- confirm against DBNET.process).
        y_tolerance: maximum vertical distance, in pixels, for two boxes to
            be considered on the same text line. Defaults to 10.

    Returns:
        list of boxes in reading order.
    """
    num_boxes = dt_boxes.shape[0]
    ordered = sorted(dt_boxes, key=lambda box: (box[0][1], box[0][0]))
    for i in range(num_boxes - 1):
        # Walk back all the way to j == 0 so the first pair is also compared;
        # stopping at j == 1 meant box 0 could never be reordered.
        for j in range(i, -1, -1):
            cur, nxt = ordered[j], ordered[j + 1]
            if abs(nxt[0][1] - cur[0][1]) < y_tolerance and nxt[0][0] < cur[0][0]:
                ordered[j], ordered[j + 1] = nxt, cur
            else:
                break
    return ordered


def _rec_batch_range():
    """Return (min_batnum, max_batnum) used for recognizer warmup batches.

    Defaults are 8 and 24; the OCR_REC_MIN_BATNUM / OCR_REC_MAX_BATNUM
    environment variables override them.

    Raises:
        ValueError: if min_batnum is not positive or max_batnum is not a
            multiple of min_batnum.
    """
    max_batnum = int(os.environ.get("OCR_REC_MAX_BATNUM", 24))
    min_batnum = int(os.environ.get("OCR_REC_MIN_BATNUM", 8))
    if min_batnum <= 0:
        raise ValueError("min_batnum must be positive.")
    # Integer modulo instead of comparing float division results.
    if max_batnum % min_batnum != 0:
        raise ValueError("max_batnum must be multiple of min_batnum.")
    return min_batnum, max_batnum


def _warmup(dbnet, crnn, anglenet):
    """Run every model on the bundled warmup images and print the elapsed time.

    The recognizer is fed batches of min_batnum, 2*min_batnum, ..., max_batnum
    copies of each warmup crop (see _rec_batch_range).
    """
    # Validate the batch configuration before spending time on any model.
    min_batnum, max_batnum = _rec_batch_range()
    warmup_file_list = get_image_file_list("./warmup_images_5/")
    warmup_file_rec_list = get_image_file_list("./warmup_images_rec/")
    warmup_file_cls = "./warmup_images_rec/ArT_2708_1.jpg"

    startwarm = time.perf_counter()
    for warmup_file in warmup_file_list:
        print(warmup_file)
        img_warm = cv2.imread(warmup_file)
        dbnet.process(img_warm)

    for warmup_file_rec in warmup_file_rec_list:
        print(warmup_file_rec)
        img_warm_rec = cv2.imread(warmup_file_rec)
        for multiple in range(1, max_batnum // min_batnum + 1):
            crnn([img_warm_rec] * (min_batnum * multiple))

    warmup_img_cls = cv2.imread(warmup_file_cls)
    anglenet(warmup_img_cls)
    print("warmup time:", time.perf_counter() - startwarm)


def _process_image(image_file, dbnet, crnn, anglenet, use_angle_cls):
    """Detect, optionally orientation-correct, and recognize text in one image.

    Prints per-stage timings and one recognition result per detected box.
    Unreadable images and images with no detections are reported and skipped.
    """
    print(image_file)
    img = cv2.imread(image_file)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        print("failed to read image:", image_file)
        return
    ori_im = img.copy()

    st = time.perf_counter()
    dt_boxes, _scores = dbnet.process(img)
    db_time = time.perf_counter()
    print(0 if dt_boxes is None else len(dt_boxes))
    print("db time:", db_time - st)
    if dt_boxes is None:
        # Skip just this image; the caller continues with the rest.
        return

    dt_boxes = sorted_boxes(dt_boxes)
    st_ang = time.perf_counter()
    img_crop_list = []
    for box in dt_boxes:
        # astype() already returns a new array, so the original box is untouched.
        img_crop = get_rotate_crop_image(ori_im, box.astype(np.float32))
        # A falsy classifier result triggers a 180-degree rotation
        # (presumably "crop is upside down" -- confirm in AngleNetHandle).
        if use_angle_cls and not anglenet(img_crop):
            img_crop = cv2.rotate(img_crop, cv2.ROTATE_180)
        img_crop_list.append(img_crop)
    ed_ang = time.perf_counter()
    print("ang time:", ed_ang - st_ang)

    st_rec = time.perf_counter()
    # NOTE(review): CRNNHandle's behaviour on an empty batch is not visible
    # here, so an empty crop list is short-circuited.
    rec_res = crnn(img_crop_list)[0] if img_crop_list else []
    ed = time.perf_counter()
    print("rec time:", ed - st_rec)
    print("infer time:", ed - st)
    for _crop, res in zip(img_crop_list, rec_res):
        print(res)


def main():
    """Load the models, optionally warm them up, then OCR every image in --img_dir."""
    dbnet = DBNET(MODEL_PATH=args.det_model_dir)
    crnn = CRNNHandle(model_path=args.rec_model_dir)
    anglenet = AngleNetHandle(model_path=args.cls_model_dir)

    if args.warmup:
        _warmup(dbnet, crnn, anglenet)

    for image_file in get_image_file_list(args.img_dir):
        _process_image(image_file, dbnet, crnn, anglenet, args.use_angle_cls)


if __name__ == "__main__":
    main()