# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

# Make this script's directory and the directory two levels above it
# importable. This must run before the `tools.*` / `ppocr.*` imports below,
# which are resolved through these sys.path entries.
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))

# NOTE(review): presumably has to be set before the inference backend is
# loaded further down -- keep it above the remaining imports.
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

import cv2
import numpy as np
import time
import sys
from scipy.spatial import distance as dist

import tools.infer.utility as utility
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
import json

# Module-wide logger shared by TextDetector and the script entry point.
logger = get_logger()


class TextDetector(object):
    """Text detector: pre-processing, model inference and box post-processing.

    The detection algorithm is selected by ``args.det_algorithm`` and must be
    one of ``"DB"``, ``"EAST"``, ``"SAST"``, ``"PSE"`` or ``"FCE"``. Calling
    the instance with an image returns the detected boxes and the elapsed
    time in seconds.
    """

    def __init__(self, args):
        """Build the pre/post-processing operators and the predictor.

        Args:
            args: argument namespace. Reads ``det_algorithm``, ``use_onnx``,
                ``benchmark`` and the algorithm-specific ``det_*`` options
                referenced below.

        Note:
            An unknown ``det_algorithm`` is logged and the process exits via
            ``sys.exit(0)``.
        """
        self.args = args
        self.det_algorithm = args.det_algorithm
        self.use_onnx = args.use_onnx
        pre_process_list = [{
            'DetResizeForTest': {
                'limit_side_len': args.det_limit_side_len,
                'limit_type': args.det_limit_type,
            }
        }, {
            'NormalizeImage': {
                'std': [0.229, 0.224, 0.225],
                'mean': [0.485, 0.456, 0.406],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': ['image', 'shape']
            }
        }]
        postprocess_params = {}
        if self.det_algorithm == "DB":
            postprocess_params['name'] = 'DBPostProcess'
            postprocess_params["thresh"] = args.det_db_thresh
            postprocess_params["box_thresh"] = args.det_db_box_thresh
            postprocess_params["max_candidates"] = 1000
            postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
            postprocess_params["use_dilation"] = args.use_dilation
            postprocess_params["score_mode"] = args.det_db_score_mode
        elif self.det_algorithm == "EAST":
            postprocess_params['name'] = 'EASTPostProcess'
            postprocess_params["score_thresh"] = args.det_east_score_thresh
            postprocess_params["cover_thresh"] = args.det_east_cover_thresh
            postprocess_params["nms_thresh"] = args.det_east_nms_thresh
        elif self.det_algorithm == "SAST":
            # SAST replaces the default resize step.
            pre_process_list[0] = {
                'DetResizeForTest': {
                    'resize_long': args.det_limit_side_len
                }
            }
            postprocess_params['name'] = 'SASTPostProcess'
            postprocess_params["score_thresh"] = args.det_sast_score_thresh
            postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
            self.det_sast_polygon = args.det_sast_polygon
            if self.det_sast_polygon:
                postprocess_params["sample_pts_num"] = 6
                postprocess_params["expand_scale"] = 1.2
                postprocess_params["shrink_ratio_of_width"] = 0.2
            else:
                postprocess_params["sample_pts_num"] = 2
                postprocess_params["expand_scale"] = 1.0
                postprocess_params["shrink_ratio_of_width"] = 0.3
        elif self.det_algorithm == "PSE":
            postprocess_params['name'] = 'PSEPostProcess'
            postprocess_params["thresh"] = args.det_pse_thresh
            postprocess_params["box_thresh"] = args.det_pse_box_thresh
            postprocess_params["min_area"] = args.det_pse_min_area
            postprocess_params["box_type"] = args.det_pse_box_type
            postprocess_params["scale"] = args.det_pse_scale
            self.det_pse_box_type = args.det_pse_box_type
        elif self.det_algorithm == "FCE":
            # FCE replaces the default resize step with a fixed rescale.
            pre_process_list[0] = {
                'DetResizeForTest': {
                    'rescale_img': [1080, 736]
                }
            }
            postprocess_params['name'] = 'FCEPostProcess'
            postprocess_params["scales"] = args.scales
            postprocess_params["alpha"] = args.alpha
            postprocess_params["beta"] = args.beta
            postprocess_params["fourier_degree"] = args.fourier_degree
            postprocess_params["box_type"] = args.det_fce_box_type
        else:
            # NOTE(review): this is an error path but exits with status 0 and
            # logs at info level; kept as-is because callers may rely on it.
            logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
            sys.exit(0)

        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor(
            args, 'det', logger)

        if self.use_onnx:
            img_h, img_w = self.input_tensor.shape[2:]
            # Only pin the resize target when both spatial dims are fixed;
            # the original `img_h > 0` comparison raised TypeError when a
            # dim was reported as a symbolic string.
            if self._is_fixed_dim(img_h) and self._is_fixed_dim(img_w):
                pre_process_list[0] = {
                    'DetResizeForTest': {
                        'image_shape': [img_h, img_w]
                    }
                }
        # Built once, after every override of pre_process_list[0]. The
        # original also built it before creating the predictor and then
        # overwrote that result here.
        self.preprocess_op = create_operators(pre_process_list)

        if args.benchmark:
            import auto_log
            pid = os.getpid()
            gpu_id = utility.get_infer_gpuid()
            self.autolog = auto_log.AutoLogger(
                model_name="det",
                model_precision=args.precision,
                batch_size=1,
                data_shape="dynamic",
                save_path=None,
                inference_config=self.config,
                pids=pid,
                process_name=None,
                gpu_ids=gpu_id if args.use_gpu else None,
                time_keys=[
                    'preprocess_time', 'inference_time', 'postprocess_time'
                ],
                warmup=2,
                logger=logger)

    @staticmethod
    def _is_fixed_dim(dim):
        """Return True if ``dim`` is a concrete positive size.

        ``None`` and string (symbolic) dimensions are treated as dynamic.
        """
        if dim is None or isinstance(dim, str):
            return False
        return dim > 0

    def order_points_clockwise(self, pts):
        """Order four points as top-left, top-right, bottom-right, bottom-left.

        refer to :https://github.com/PyImageSearch/imutils/blob/9f740a53bcc2ed7eba2558afed8b4c17fd8a1d4c/imutils/perspective.py#L9

        Args:
            pts: array with one (x, y) point per row; the unpacking below
                requires exactly four rows.

        Returns:
            ``float32`` array ``[tl, tr, br, bl]``.
        """
        # sort the points based on their x-coordinates
        xSorted = pts[np.argsort(pts[:, 0]), :]

        # the two smallest-x points form the left edge, the rest the right
        leftMost = xSorted[:2, :]
        rightMost = xSorted[2:, :]

        # on the left edge, the smaller y is the top-left corner
        leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
        (tl, bl) = leftMost

        # the right-edge point farthest from top-left is bottom-right
        D = dist.cdist(tl[np.newaxis], rightMost, "euclidean")[0]
        (br, tr) = rightMost[np.argsort(D)[::-1], :]

        return np.array([tl, tr, br, bl], dtype="float32")

    def clip_det_res(self, points, img_height, img_width):
        """Clamp points in place to the image bounds and truncate to ints.

        Args:
            points: array with one (x, y) point per row; modified in place.
            img_height: image height; y is clamped to ``[0, img_height - 1]``.
            img_width: image width; x is clamped to ``[0, img_width - 1]``.

        Returns:
            The same ``points`` array.
        """
        for pno in range(points.shape[0]):
            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
        return points

    def filter_tag_det_res(self, dt_boxes, image_shape):
        """Order, clip and size-filter quadrilateral boxes.

        Each box is reordered clockwise and clipped to the image; boxes whose
        top edge or left edge is 3 pixels or shorter are dropped.

        Args:
            dt_boxes: iterable of 4-point boxes.
            image_shape: image shape; only the first two entries
                (height, width) are used.

        Returns:
            ``np.ndarray`` of the boxes that were kept.
        """
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            box = self.order_points_clockwise(box)
            box = self.clip_det_res(box, img_height, img_width)
            rect_width = int(np.linalg.norm(box[0] - box[1]))
            rect_height = int(np.linalg.norm(box[0] - box[3]))
            if rect_width <= 3 or rect_height <= 3:
                continue
            dt_boxes_new.append(box)
        dt_boxes = np.array(dt_boxes_new)
        return dt_boxes

    def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
        """Clip boxes to the image without reordering or size filtering.

        Args:
            dt_boxes: iterable of point arrays (each is clipped in place).
            image_shape: image shape; only the first two entries
                (height, width) are used.

        Returns:
            ``np.ndarray`` built from the clipped boxes.
        """
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            box = self.clip_det_res(box, img_height, img_width)
            dt_boxes_new.append(box)
        dt_boxes = np.array(dt_boxes_new)
        return dt_boxes

    def __call__(self, img):
        """Detect text boxes in a single image.

        Args:
            img: image array; it is copied, so the caller's array is not
                modified by this method.

        Returns:
            ``(dt_boxes, elapsed_seconds)``, or ``(None, 0)`` when
            pre-processing yields no image.

        Raises:
            NotImplementedError: if ``self.det_algorithm`` is not one of the
                supported algorithms.
        """
        ori_im = img.copy()
        data = {'image': img}

        st = time.time()

        if self.args.benchmark:
            self.autolog.times.start()

        data = transform(data, self.preprocess_op)
        # NOTE(review): presumably transform() returns None when an operator
        # rejects the sample -- guard before unpacking so that case follows
        # the same (None, 0) path instead of raising TypeError.
        if data is None:
            return None, 0
        img, shape_list = data
        if img is None:
            return None, 0
        # add the batch dimension expected by the predictor
        img = np.expand_dims(img, axis=0)
        shape_list = np.expand_dims(shape_list, axis=0)
        img = img.copy()

        if self.args.benchmark:
            self.autolog.times.stamp()
        if self.use_onnx:
            input_dict = {}
            input_dict[self.input_tensor.name] = img
            outputs = self.predictor.run(self.output_tensors, input_dict)
        else:
            self.input_tensor.copy_from_cpu(img)
            self.predictor.run()
            outputs = []
            for output_tensor in self.output_tensors:
                output = output_tensor.copy_to_cpu()
                outputs.append(output)
        # Stamp the end of inference for both backends; the original only
        # stamped on the non-ONNX path, leaving ONNX runs one time point
        # short of the three time_keys configured in __init__.
        if self.args.benchmark:
            self.autolog.times.stamp()

        # map raw outputs to the names the post-processor expects
        preds = {}
        if self.det_algorithm == "EAST":
            preds['f_geo'] = outputs[0]
            preds['f_score'] = outputs[1]
        elif self.det_algorithm == 'SAST':
            preds['f_border'] = outputs[0]
            preds['f_score'] = outputs[1]
            preds['f_tco'] = outputs[2]
            preds['f_tvo'] = outputs[3]
        elif self.det_algorithm in ['DB', 'PSE']:
            preds['maps'] = outputs[0]
        elif self.det_algorithm == 'FCE':
            for i, output in enumerate(outputs):
                preds['level_{}'.format(i)] = output
        else:
            raise NotImplementedError

        #self.predictor.try_shrink_memory()
        post_result = self.postprocess_op(preds, shape_list)
        dt_boxes = post_result[0]['points']
        # polygon outputs are only clipped; quadrilaterals are also
        # reordered and filtered by size
        if (self.det_algorithm == "SAST" and self.det_sast_polygon) or (
                self.det_algorithm in ["PSE", "FCE"] and
                self.postprocess_op.box_type == 'poly'):
            dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
        else:
            dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)

        if self.args.benchmark:
            self.autolog.times.end(stamp=True)
        et = time.time()
        return dt_boxes, et - st


if __name__ == "__main__":
    # Script entry point: run detection on every image found under
    # args.image_dir, save one visualisation per image and a text file with
    # all predicted boxes into ./inference_results.
    args = utility.parse_args()
    image_file_list = get_image_file_list(args.image_dir)
    text_detector = TextDetector(args)
    count = 0
    total_time = 0
    draw_img_save = "./inference_results"

    if args.warmup:
        # two throw-away runs on a random image before timing real inputs
        img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
        for i in range(2):
            text_detector(img)

    os.makedirs(draw_img_save, exist_ok=True)
    save_results = []
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        st = time.time()
        dt_boxes, _ = text_detector(img)
        elapse = time.time() - st
        if dt_boxes is None:
            # The detector returns (None, 0) when pre-processing yields no
            # image; skip instead of crashing on the iteration below.
            logger.info("no detection result for image:{}".format(image_file))
            continue
        # the first processed image is excluded from the accumulated time
        if count > 0:
            total_time += elapse
        count += 1
        save_pred = os.path.basename(image_file) + "\t" + str(
            json.dumps([x.tolist() for x in dt_boxes])) + "\n"
        save_results.append(save_pred)
        logger.info(save_pred)
        logger.info("The predict time of {}: {}".format(image_file, elapse))
        src_im = utility.draw_text_det_res(dt_boxes, image_file)
        img_name_pure = os.path.split(image_file)[-1]
        img_path = os.path.join(draw_img_save,
                                "det_res_{}".format(img_name_pure))
        cv2.imwrite(img_path, src_im)
        logger.info("The visualized image saved in {}".format(img_path))

    # the context manager closes the file; no explicit close() is needed
    with open(os.path.join(draw_img_save, "det_results.txt"), 'w') as f:
        f.writelines(save_results)
    if args.benchmark:
        text_detector.autolog.report()