predict_system.py 6.04 KB
Newer Older
LDOUBLEV's avatar
LDOUBLEV committed
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15
import os
import sys
16
__dir__ = os.path.dirname(os.path.abspath(__file__))
17
sys.path.append(__dir__)
18
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
LDOUBLEV's avatar
LDOUBLEV committed
19

LDOUBLEV's avatar
LDOUBLEV committed
20
21
22
23
import cv2
import copy
import numpy as np
import time
LDOUBLEV's avatar
LDOUBLEV committed
24
from PIL import Image
WenmuZhou's avatar
WenmuZhou committed
25
import tools.infer.utility as utility
LDOUBLEV's avatar
LDOUBLEV committed
26
from tools.infer.utility import draw_ocr
WenmuZhou's avatar
WenmuZhou committed
27
28
29
30
import tools.infer.predict_rec as predict_rec
import tools.infer.predict_det as predict_det
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
LDOUBLEV's avatar
LDOUBLEV committed
31
32
33
34
35
36
37
38


class TextSystem(object):
    """Two-stage OCR pipeline: detect text boxes, then recognize each crop.

    The detector and recognizer are built from the same ``args`` object and
    are invoked in sequence by ``__call__``.
    """

    def __init__(self, args):
        """Build the detector and recognizer from the parsed arguments."""
        self.text_detector = predict_det.TextDetector(args)
        self.text_recognizer = predict_rec.TextRecognizer(args)

    def get_rotate_crop_image(self, img, points):
        """Cut a quadrilateral region out of ``img`` and rectify it.

        args:
            img(array): source image; it is passed unchanged to
                ``cv2.warpPerspective``.
            points(array): four corner points of the region. They are mapped,
                in order, onto top-left, top-right, bottom-right and
                bottom-left of the output rectangle.
                NOTE(review): presumably float32 with shape [4, 2], as
                ``cv2.getPerspectiveTransform`` expects - confirm with caller.
        return:
            the rectified crop; crops whose height is at least 1.5 times
            their width are rotated by 90 degrees with ``np.rot90``.
        """
        # Output size: the longer of each pair of opposite edges.
        img_crop_width = int(
            max(
                np.linalg.norm(points[0] - points[1]),
                np.linalg.norm(points[2] - points[3])))
        img_crop_height = int(
            max(
                np.linalg.norm(points[0] - points[3]),
                np.linalg.norm(points[1] - points[2])))
        pts_std = np.float32([[0, 0], [img_crop_width, 0],
                              [img_crop_width, img_crop_height],
                              [0, img_crop_height]])
        M = cv2.getPerspectiveTransform(points, pts_std)
        dst_img = cv2.warpPerspective(
            img,
            M, (img_crop_width, img_crop_height),
            borderMode=cv2.BORDER_REPLICATE,
            flags=cv2.INTER_CUBIC)
        dst_img_height, dst_img_width = dst_img.shape[0:2]
        # Tall, narrow crops are turned on their side before recognition.
        if dst_img_height * 1.0 / dst_img_width >= 1.5:
            dst_img = np.rot90(dst_img)
        return dst_img

    def print_draw_crop_rec_res(self, img_crop_list, rec_res):
        """Debug helper: save every crop to ./output/ and print its result.

        args:
            img_crop_list(list): cropped images, one per detected box.
            rec_res(list): recognition results, indexed like the crops.
        """
        # cv2.imwrite does not create missing directories, so make sure the
        # target exists before writing.
        if not os.path.exists("./output"):
            os.makedirs("./output")
        for bno, img_crop in enumerate(img_crop_list):
            cv2.imwrite("./output/img_crop_%d.jpg" % bno, img_crop)
            print(bno, rec_res[bno])

    def __call__(self, img):
        """Run detection and recognition on one image.

        args:
            img(array): input image; a copy is kept for cropping so the
                crops do not depend on what the detector does with ``img``.
        return:
            (dt_boxes, rec_res): the sorted boxes and the recognizer output
            for the crops taken in that order, or (None, None) when the
            detector returns no boxes object.
        """
        ori_im = img.copy()
        dt_boxes, elapse = self.text_detector(img)
        # Check for None before touching dt_boxes: len(None) would raise.
        if dt_boxes is None:
            return None, None
        print("dt_boxes num : {}, elapse : {}".format(len(dt_boxes), elapse))

        dt_boxes = sorted_boxes(dt_boxes)

        # Each box is deep-copied so cropping can never alter the boxes that
        # are returned to the caller.
        img_crop_list = [
            self.get_rotate_crop_image(ori_im, copy.deepcopy(box))
            for box in dt_boxes
        ]
        rec_res, elapse = self.text_recognizer(img_crop_list)
        print("rec_res num  : {}, elapse : {}".format(len(rec_res), elapse))
        # For debugging, self.print_draw_crop_rec_res(img_crop_list, rec_res)
        # can be called here to dump the crops.
        return dt_boxes, rec_res


97
98
99
100
def sorted_boxes(dt_boxes):
    """
    Sort text boxes in order from top to bottom, left to right
    args:
        dt_boxes(array): detected text boxes, a sequence of N boxes; only the
            first point of each box (box[0] = [x, y]) is used for ordering
    return:
        list of the N boxes in sorted order
    """
    # Coarse ordering by the first point: y first, then x.
    _boxes = sorted(dt_boxes, key=lambda box: (box[0][1], box[0][0]))
    num_boxes = len(_boxes)

    # Boxes whose first points differ by less than 10 in y are treated as one
    # text line and must be ordered by x. Box i + 1 is moved left past every
    # same-line neighbour with a larger x; a single adjacent-swap pass is not
    # enough when a box is more than one position out of place.
    for i in range(num_boxes - 1):
        for j in range(i, -1, -1):
            same_line = abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10
            if same_line and _boxes[j + 1][0][0] < _boxes[j][0][0]:
                _boxes[j], _boxes[j + 1] = _boxes[j + 1], _boxes[j]
            else:
                break
    return _boxes


118
def main(args):
    """Run the OCR pipeline over every image found under ``args.image_dir``.

    For each image the detected texts with a score of at least 0.5 are
    printed, and a visualization is written to ./inference_results/ under the
    image's base name.

    args:
        args: parsed command-line arguments; this function reads
            ``image_dir``, ``use_gpu`` and ``enable_mkldnn`` and passes the
            whole object on to ``TextSystem``.
    """
    image_file_list = get_image_file_list(args.image_dir)
    text_sys = TextSystem(args)
    is_visualize = True
    tackle_img_num = 0
    drop_score = 0.5
    draw_img_save = "./inference_results/"
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            # NOTE(review): `logger` is a module global that is only assigned
            # in the `__main__` block; calling main() from another module
            # would raise NameError here - confirm intended usage.
            logger.info("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        tackle_img_num += 1
        # On CPU with MKL-DNN enabled the whole system is rebuilt every 30
        # images. NOTE(review): presumably a workaround for resource growth
        # in that mode - confirm.
        if not args.use_gpu and args.enable_mkldnn and tackle_img_num % 30 == 0:
            text_sys = TextSystem(args)
        dt_boxes, rec_res = text_sys(img)
        elapse = time.time() - starttime
        print("Predict time of %s: %.3fs" % (image_file, elapse))

        # TextSystem returns (None, None) when detection yields no boxes
        # object; there is nothing to print or draw for such an image.
        if dt_boxes is None:
            continue

        for text, score in rec_res:
            if score >= drop_score:
                text_str = "%s, %.3f" % (text, score)
                print(text_str)

        if is_visualize:
            image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            boxes = dt_boxes
            txts = [res[0] for res in rec_res]
            scores = [res[1] for res in rec_res]

            draw_img = draw_ocr(
                image, boxes, txts, scores, drop_score=drop_score)
            if not os.path.exists(draw_img_save):
                os.makedirs(draw_img_save)
            save_path = os.path.join(draw_img_save,
                                     os.path.basename(image_file))
            # The channel axis is reversed before saving.
            # NOTE(review): assumes draw_ocr returns an RGB array that must
            # go back to BGR for cv2.imwrite - confirm.
            cv2.imwrite(save_path, draw_img[:, :, ::-1])
            print("The visualized image saved in {}".format(save_path))
162
163
164


if __name__ == "__main__":
    # Script entry point. `logger` is bound at module level because main()
    # reads it as a global when an image cannot be loaded.
    logger = get_logger()
    cli_args = utility.parse_args()
    main(cli_args)