predict_system.py 6.8 KB
Newer Older
LDOUBLEV's avatar
LDOUBLEV committed
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15
import os
import sys
LDOUBLEV's avatar
LDOUBLEV committed
16
import subprocess
WenmuZhou's avatar
WenmuZhou committed
17

18
# Make this script's directory and the repository root (two levels up)
# importable, so the `tools.*` / `ppocr.*` imports below resolve when the
# script is run directly.
# The directory is deliberately NOT stored under the name `__dir__`: since
# Python 3.7 (PEP 562) a module-level `__dir__` is *called* by dir(module),
# so binding a string there makes dir()/help() on this module raise TypeError.
_cur_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_cur_dir)
sys.path.append(os.path.abspath(os.path.join(_cur_dir, '../..')))

# Set before the Paddle-backed modules below are imported; presumably the
# allocator flag is read when Paddle initializes -- confirm.
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

LDOUBLEV's avatar
LDOUBLEV committed
24
25
26
27
import cv2
import copy
import numpy as np
import time
WenmuZhou's avatar
WenmuZhou committed
28
import logging
LDOUBLEV's avatar
LDOUBLEV committed
29
from PIL import Image
WenmuZhou's avatar
WenmuZhou committed
30
31
32
import tools.infer.utility as utility
import tools.infer.predict_rec as predict_rec
import tools.infer.predict_det as predict_det
WenmuZhou's avatar
WenmuZhou committed
33
import tools.infer.predict_cls as predict_cls
WenmuZhou's avatar
WenmuZhou committed
34
35
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
WenmuZhou's avatar
WenmuZhou committed
36
from tools.infer.utility import draw_ocr_box_txt, get_rotate_crop_image
WenmuZhou's avatar
WenmuZhou committed
37
38
# Module-wide logger used by TextSystem and main() for progress, result and
# timing messages.
logger = get_logger()

LDOUBLEV's avatar
LDOUBLEV committed
39
40
41

class TextSystem(object):
    """End-to-end OCR pipeline for a single image.

    Runs text detection, optional text-angle classification, and text
    recognition, then drops recognition results whose score is below
    ``args.drop_score``.
    """

    def __init__(self, args):
        """Build the sub-models from parsed command-line ``args``.

        Reads ``args.show_log``, ``args.use_angle_cls`` and
        ``args.drop_score``; the whole ``args`` object is also forwarded to
        the detector, recognizer and (when enabled) classifier constructors.
        """
        if not args.show_log:
            # Raise the threshold to INFO so the per-stage logger.debug
            # timings emitted in __call__ are suppressed.
            logger.setLevel(logging.INFO)

        self.text_detector = predict_det.TextDetector(args)
        self.text_recognizer = predict_rec.TextRecognizer(args)
        self.use_angle_cls = args.use_angle_cls
        # Recognition results scoring below this threshold are discarded.
        self.drop_score = args.drop_score
        if self.use_angle_cls:
            self.text_classifier = predict_cls.TextClassifier(args)

    def print_draw_crop_rec_res(self, img_crop_list, rec_res):
        """Debug helper: save each crop under ./output/ and log its result.

        Args:
            img_crop_list: cropped text-region images, as passed to the
                recognizer.
            rec_res: recognition results aligned with ``img_crop_list``.
        """
        # cv2.imwrite does not create missing directories (it just fails).
        os.makedirs("./output", exist_ok=True)
        for bno, (img_crop, res) in enumerate(zip(img_crop_list, rec_res)):
            cv2.imwrite("./output/img_crop_%d.jpg" % bno, img_crop)
            # Format explicitly: logger.info(msg, *args) would treat the int
            # `bno` as a format string and fail on the extra argument.
            logger.info("{}, {}".format(bno, res))

    def __call__(self, img, cls=True):
        """Run detection -> (optional) angle classification -> recognition.

        Args:
            img: input image. It is copied before detection and text regions
                are cropped from the copy.
            cls: if False, skip angle classification even when the system
                was built with ``use_angle_cls``.

        Returns:
            ``(boxes, rec_res)``: the sorted detected boxes and their
            ``(text, score)`` recognition results, both filtered to entries
            with ``score >= drop_score``. Returns ``(None, None)`` when the
            detector returns ``None`` for the boxes.
        """
        # Keep a pristine copy for cropping, presumably in case the detector
        # alters `img` in place.
        ori_im = img.copy()
        dt_boxes, elapse = self.text_detector(img)
        # Must be checked before logging: len(None) raises TypeError.
        if dt_boxes is None:
            return None, None
        logger.debug("dt_boxes num : {}, elapse : {}".format(
            len(dt_boxes), elapse))

        dt_boxes = sorted_boxes(dt_boxes)
        # Pass a copy of each box so the box we return is unaffected if the
        # crop helper mutates its argument.
        img_crop_list = [
            get_rotate_crop_image(ori_im, copy.deepcopy(box))
            for box in dt_boxes
        ]

        if self.use_angle_cls and cls:
            img_crop_list, _, elapse = self.text_classifier(img_crop_list)
            logger.debug("cls num  : {}, elapse : {}".format(
                len(img_crop_list), elapse))

        rec_res, elapse = self.text_recognizer(img_crop_list)
        logger.debug("rec_res num  : {}, elapse : {}".format(
            len(rec_res), elapse))

        filter_boxes, filter_rec_res = [], []
        for box, rec_result in zip(dt_boxes, rec_res):
            _, score = rec_result
            if score >= self.drop_score:
                filter_boxes.append(box)
                filter_rec_res.append(rec_result)
        return filter_boxes, filter_rec_res
LDOUBLEV's avatar
LDOUBLEV committed
91
92


93
94
95
96
def sorted_boxes(dt_boxes, y_threshold=10):
    """
    Sort text boxes in reading order: top to bottom, left to right.

    Boxes are first sorted by the (y, x) of their first point (box[0],
    presumably the top-left corner). Boxes whose first points differ in y
    by less than ``y_threshold`` are treated as lying on the same line.
    Each box is then moved leftwards past same-line neighbours that lie
    further right. A single bubble pass is not enough when three or more
    boxes share a line, so an insertion step is used here.

    args:
        dt_boxes(array): detected text boxes with shape [N, 4, 2] (a list
            of [4, 2] boxes also works)
        y_threshold(number): maximum vertical distance, in pixels, between
            two boxes' first points for them to count as the same line
    return:
        sorted boxes(list) of N boxes, each with shape [4, 2]
    """
    _boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
    num_boxes = len(_boxes)

    for i in range(num_boxes - 1):
        # Walk box i+1 back towards the start while it is on the same line
        # as its left neighbour but positioned further left.
        for j in range(i, -1, -1):
            same_line = abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < y_threshold
            if same_line and _boxes[j + 1][0][0] < _boxes[j][0][0]:
                _boxes[j], _boxes[j + 1] = _boxes[j + 1], _boxes[j]
            else:
                break
    return _boxes


114
def main(args):
    """Run the OCR system over the images found under ``args.image_dir``.

    When several worker processes are launched (see the ``__main__``
    launcher), this process handles only every ``args.total_process_num``-th
    file, starting at index ``args.process_id``. For each image it logs the
    predict time and the recognized ``text, score`` pairs, and writes a
    visualization to ./inference_results/. With ``args.benchmark`` the
    detector's and recognizer's ``autolog`` reports are printed at the end.
    """
    image_file_list = get_image_file_list(args.image_dir)
    # Static sharding of the file list across worker processes.
    image_file_list = image_file_list[args.process_id::args.total_process_num]
    text_sys = TextSystem(args)
    is_visualize = True
    font_path = args.vis_font_path
    drop_score = args.drop_score
    draw_img_save = "./inference_results/"
    if is_visualize:
        # Create the output directory once instead of checking per image.
        os.makedirs(draw_img_save, exist_ok=True)

    # warm up 10 times
    if args.warmup:
        img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
        for _ in range(10):
            text_sys(img)

    total_time = 0
    _st = time.time()
    for idx, image_file in enumerate(image_file_list):
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        dt_boxes, rec_res = text_sys(img)
        elapse = time.time() - starttime
        total_time += elapse
        # TextSystem returns (None, None) when detection yields no boxes;
        # treat that as "nothing recognized" instead of iterating None.
        if dt_boxes is None:
            dt_boxes, rec_res = [], []

        logger.info(
            str(idx) + "  Predict time of %s: %.3fs" % (image_file, elapse))
        for text, score in rec_res:
            logger.info("{}, {:.3f}".format(text, score))

        if is_visualize:
            image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            txts = [text for text, _ in rec_res]
            scores = [score for _, score in rec_res]
            draw_img = draw_ocr_box_txt(
                image,
                dt_boxes,
                txts,
                scores,
                drop_score=drop_score,
                font_path=font_path)
            if flag:
                # GIF input: save the visualization with a .png extension.
                image_file = image_file[:-3] + "png"
            save_path = os.path.join(draw_img_save,
                                     os.path.basename(image_file))
            # Drawing was done on an RGB image; reverse the channel order
            # back to BGR for cv2.imwrite.
            cv2.imwrite(save_path, draw_img[:, :, ::-1])
            logger.info("The visualized image saved in {}".format(save_path))

    # Wall-clock time of the whole loop (image I/O and visualization included).
    logger.info("The predict total time is {}".format(time.time() - _st))
    # Sum of the text_sys(...) call times only.
    logger.info("\nThe predict total time is {}".format(total_time))
    if args.benchmark:
        text_sys.text_detector.autolog.report()
        text_sys.text_recognizer.autolog.report()
LDOUBLEV's avatar
LDOUBLEV committed
179

LDOUBLEV's avatar
LDOUBLEV committed
180

LDOUBLEV's avatar
LDOUBLEV committed
181
if __name__ == "__main__":
    args = utility.parse_args()
    if not args.use_mp:
        # Single-process mode: handle every image in this process.
        main(args)
    else:
        # Multi-process mode: re-run this same script once per worker,
        # adding that worker's shard index and disabling use_mp so the
        # children do not spawn further workers. Child output is forwarded
        # to this process's stdout.
        workers = [
            subprocess.Popen(
                [sys.executable, "-u"] + sys.argv + [
                    "--process_id={}".format(worker_id),
                    "--use_mp={}".format(False)
                ],
                stdout=sys.stdout,
                stderr=sys.stdout)
            for worker_id in range(args.total_process_num)
        ]
        # Block until every worker has finished.
        for worker in workers:
            worker.wait()