main.py 4.49 KB
Newer Older
chenxj's avatar
chenxj committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import cv2
import copy
import numpy as np
import json
import time
from utils import get_image_file_list
from utils import get_rotate_crop_image
from dbnet.dbnet_infer import DBNET
from crnn.CRNN import CRNNHandle
from angnet.angle import AngleNetHandle
from PIL import Image
import os
import argparse


def str2bool(v):
    """Interpret a command-line string as a boolean flag.

    Matching is case-insensitive: "true", "t" and "1" yield True and any
    other string yields False.
    """
    lowered = v.lower()
    return lowered == "true" or lowered == "t" or lowered == "1"

parser = argparse.ArgumentParser()

# On/off switches given as strings ("true"/"t"/"1", case-insensitive),
# converted with str2bool; both are off unless explicitly enabled.
for _flag in ("--warmup", "--use_angle_cls"):
    parser.add_argument(_flag, type=str2bool, default=False)

# Input image directory, then the model location for each pipeline stage
# (detection, recognition, angle classification) as consumed by main().
for _flag in ("--img_dir", "--det_model_dir", "--rec_model_dir", "--cls_model_dir"):
    parser.add_argument(_flag, type=str)
del _flag

# Parsed once at import time; main() reads this module-level namespace.
args = parser.parse_args()

def sorted_boxes(dt_boxes):
    """Sort detected text boxes top-to-bottom, then left-to-right.

    Boxes are first ordered by the (y, x) coordinates of their first corner
    point (``box[0]``).  A second, insertion-style pass then moves a box to
    the left of its predecessor whenever the two first-corner y values
    differ by less than 10 (i.e. they are treated as the same text line)
    and the box lies further left.

    Args:
        dt_boxes: sequence of N detected boxes (e.g. an array presumably of
            shape (N, 4, 2)); only ``box[0][0]`` (x) and ``box[0][1]`` (y)
            of each box are read.

    Returns:
        list: the same box objects in reading order.
    """
    # Primary order: first corner's y, ties broken by its x.
    _boxes = sorted(dt_boxes, key=lambda box: (box[0][1], box[0][0]))

    # Bubble each box leftwards past neighbours on the same line that sit
    # further right.  j must go down to 0 inclusive, otherwise the first
    # pair of boxes is never compared.
    for i in range(len(_boxes) - 1):
        for j in range(i, -1, -1):
            same_line = abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10
            if same_line and _boxes[j + 1][0][0] < _boxes[j][0][0]:
                _boxes[j], _boxes[j + 1] = _boxes[j + 1], _boxes[j]
            else:
                break
    return _boxes


def _read_rec_batch_bounds():
    """Return ``(max_batnum, min_batnum)`` used to size recognizer warm-up batches.

    Defaults are 24 and 8; the environment variables ``OCR_REC_MAX_BATNUM``
    and ``OCR_REC_MIN_BATNUM`` override them when set.

    Raises:
        ValueError: if min_batnum is not positive, or max_batnum is not an
            exact multiple of min_batnum.
    """
    max_batnum = int(os.environ.get("OCR_REC_MAX_BATNUM", 24))
    min_batnum = int(os.environ.get("OCR_REC_MIN_BATNUM", 8))
    # Explicit checks instead of assert (stripped under -O); integer modulo
    # avoids the float-division comparison and a ZeroDivisionError.
    if min_batnum <= 0:
        raise ValueError("min_batnum must be a positive integer.")
    if max_batnum % min_batnum != 0:
        raise ValueError("max_batnum must be multiple of min_batnum.")
    return max_batnum, min_batnum


def _warmup(dbnet, crnn, anglenet):
    """Run the bundled warm-up images through every model once and print the elapsed time."""
    det_files = get_image_file_list("./warmup_images_5/")
    rec_files = get_image_file_list("./warmup_images_rec/")
    cls_file = "./warmup_images_rec/ArT_2708_1.jpg"
    # Loop-invariant: read the batch bounds once, not once per image.
    max_batnum, min_batnum = _read_rec_batch_bounds()

    start = time.time()
    for det_file in det_files:
        print(det_file)
        dbnet.process(cv2.imread(det_file))
    for rec_file in rec_files:
        print(rec_file)
        rec_img = cv2.imread(rec_file)
        # Batch sizes min_batnum, 2*min_batnum, ..., max_batnum.
        for batch_size in range(min_batnum, max_batnum + 1, min_batnum):
            crnn([rec_img] * batch_size)
    anglenet(cv2.imread(cls_file))
    print("warmup time:", time.time() - start)


def _ocr_image(image_file, dbnet, crnn, anglenet):
    """OCR one image file and print the recognized text with per-stage timings.

    The image is passed through the detector, each detected box is cropped
    (and, with --use_angle_cls, optionally rotated 180 degrees), and all
    crops are recognized in one call.  An unreadable image or a ``None``
    detection result is reported and skipped so that the caller can move on
    to the next file.
    """
    print(image_file)
    img = cv2.imread(image_file)
    if img is None:
        # cv2.imread returns None instead of raising for unreadable files.
        print("failed to read image, skipping:", image_file)
        return
    ori_im = img.copy()

    st = time.time()
    dt_boxes, _ = dbnet.process(img)
    db_time = time.time()
    # Check for None before calling len() on the result.
    if dt_boxes is None:
        print("no text boxes returned, skipping:", image_file)
        return
    print(len(dt_boxes))
    print("db time:", db_time - st)

    st_ang = time.time()
    img_crop_list = []
    for box in sorted_boxes(dt_boxes):
        # astype() returns a new float32 array, so the box is not mutated.
        img_crop = get_rotate_crop_image(ori_im, box.astype(np.float32))
        # Rotate by 180 degrees when the angle classifier returns a falsy value.
        if args.use_angle_cls and not anglenet(img_crop):
            img_crop = cv2.rotate(img_crop, cv2.ROTATE_180)
        img_crop_list.append(img_crop)
    print("ang time:", time.time() - st_ang)

    st_rec = time.time()
    rec_res, _ = crnn(img_crop_list)
    ed = time.time()
    print("rec time:", ed - st_rec)
    print("infer time:", ed - st)
    for res, _crop in zip(rec_res, img_crop_list):
        print(res)


def main():
    """Build the models from the CLI args, optionally warm them up, then OCR every image in --img_dir."""
    dbnet = DBNET(MODEL_PATH=args.det_model_dir)
    crnn = CRNNHandle(model_path=args.rec_model_dir)
    anglenet = AngleNetHandle(model_path=args.cls_model_dir)
    if args.warmup:
        _warmup(dbnet, crnn, anglenet)

    for image_file in get_image_file_list(args.img_dir):
        _ocr_image(image_file, dbnet, crnn, anglenet)


if __name__ == "__main__":
    main()