# predict_det.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

# Make this script's own directory and the repository root (two levels up)
# importable, so the `tools.*` and `ppocr.*` packages below resolve when the
# file is run directly as a script.
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))

# Set before the project imports below, presumably so the Paddle runtime
# picks the allocator strategy up when it is first imported.
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

import json
import time

import cv2
import numpy as np
# NOTE(review): `dist` is not referenced anywhere else in this file.
from scipy.spatial import distance as dist

import tools.infer.utility as utility
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process

logger = get_logger()

class TextDetector(object):
    """Text detector wrapping a Paddle Inference or ONNX Runtime model.

    The constructor builds the pre-processing pipeline and the
    post-processing operator for ``args.det_algorithm`` (DB, EAST, SAST, PSE
    or FCE) and creates the predictor through ``utility.create_predictor``.
    Calling the instance on an image returns ``(dt_boxes, elapsed_seconds)``.
    """

    def __init__(self, args):
        """Configure pre/post-processing and create the predictor.

        Args:
            args: parsed command-line options (as produced by
                ``utility.parse_args``). The ``det_*`` options, ``use_onnx``,
                ``benchmark``, ``precision`` and ``use_gpu`` are read here.

        Exits the process with status 1 when ``args.det_algorithm`` is not a
        supported algorithm.
        """
        self.args = args
        self.det_algorithm = args.det_algorithm
        self.use_onnx = args.use_onnx
        pre_process_list = [{
            'DetResizeForTest': {
                'limit_side_len': args.det_limit_side_len,
                'limit_type': args.det_limit_type,
            }
        }, {
            'NormalizeImage': {
                'std': [0.229, 0.224, 0.225],
                'mean': [0.485, 0.456, 0.406],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': ['image', 'shape']
            }
        }]

        postprocess_params = {}
        if self.det_algorithm == "DB":
            postprocess_params['name'] = 'DBPostProcess'
            postprocess_params["thresh"] = args.det_db_thresh
            postprocess_params["box_thresh"] = args.det_db_box_thresh
            postprocess_params["max_candidates"] = 1000
            postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
            postprocess_params["use_dilation"] = args.use_dilation
            postprocess_params["score_mode"] = args.det_db_score_mode
        elif self.det_algorithm == "EAST":
            postprocess_params['name'] = 'EASTPostProcess'
            postprocess_params["score_thresh"] = args.det_east_score_thresh
            postprocess_params["cover_thresh"] = args.det_east_cover_thresh
            postprocess_params["nms_thresh"] = args.det_east_nms_thresh
        elif self.det_algorithm == "SAST":
            # SAST resizes by its long side instead of the limit_* options.
            pre_process_list[0] = {
                'DetResizeForTest': {
                    'resize_long': args.det_limit_side_len
                }
            }
            postprocess_params['name'] = 'SASTPostProcess'
            postprocess_params["score_thresh"] = args.det_sast_score_thresh
            postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
            self.det_sast_polygon = args.det_sast_polygon
            if self.det_sast_polygon:
                postprocess_params["sample_pts_num"] = 6
                postprocess_params["expand_scale"] = 1.2
                postprocess_params["shrink_ratio_of_width"] = 0.2
            else:
                postprocess_params["sample_pts_num"] = 2
                postprocess_params["expand_scale"] = 1.0
                postprocess_params["shrink_ratio_of_width"] = 0.3
        elif self.det_algorithm == "PSE":
            postprocess_params['name'] = 'PSEPostProcess'
            postprocess_params["thresh"] = args.det_pse_thresh
            postprocess_params["box_thresh"] = args.det_pse_box_thresh
            postprocess_params["min_area"] = args.det_pse_min_area
            postprocess_params["box_type"] = args.det_pse_box_type
            postprocess_params["scale"] = args.det_pse_scale
            self.det_pse_box_type = args.det_pse_box_type
        elif self.det_algorithm == "FCE":
            # FCE rescales to a fixed target instead of the limit_* options.
            pre_process_list[0] = {
                'DetResizeForTest': {
                    'rescale_img': [1080, 736]
                }
            }
            postprocess_params['name'] = 'FCEPostProcess'
            postprocess_params["scales"] = args.scales
            postprocess_params["alpha"] = args.alpha
            postprocess_params["beta"] = args.beta
            postprocess_params["fourier_degree"] = args.fourier_degree
            postprocess_params["box_type"] = args.det_fce_box_type
        else:
            # A misconfiguration is a failure: report it as an error and exit
            # with a non-zero status so calling scripts can detect it.
            logger.error("unknown det_algorithm:{}".format(self.det_algorithm))
            sys.exit(1)

        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor(
            args, 'det', logger)

        if self.use_onnx:
            # A model exported with a fixed input size needs images resized to
            # exactly that size. Dynamic dims may be reported as None or as a
            # symbolic name (str), which must not be compared with 0.
            img_h, img_w = self.input_tensor.shape[2:]
            if (isinstance(img_h, (int, np.integer)) and
                    isinstance(img_w, (int, np.integer)) and img_h > 0 and
                    img_w > 0):
                pre_process_list[0] = {
                    'DetResizeForTest': {
                        'image_shape': [img_h, img_w]
                    }
                }
        # Built once, after every adjustment to pre_process_list is final.
        self.preprocess_op = create_operators(pre_process_list)

        if args.benchmark:
            import auto_log
            pid = os.getpid()
            gpu_id = utility.get_infer_gpuid()
            self.autolog = auto_log.AutoLogger(
                model_name="det",
                model_precision=args.precision,
                batch_size=1,
                data_shape="dynamic",
                save_path=None,
                inference_config=self.config,
                pids=pid,
                process_name=None,
                gpu_ids=gpu_id if args.use_gpu else None,
                time_keys=[
                    'preprocess_time', 'inference_time', 'postprocess_time'
                ],
                warmup=2,
                logger=logger)

    def order_points_clockwise(self, pts):
        """Order four (x, y) corner points clockwise in image coordinates.

        The result is [top-left, top-right, bottom-right, bottom-left] with y
        pointing down: the smallest x + y is top-left, the largest is
        bottom-right, and of the two remaining points the one with the
        smaller y - x is top-right.

        Args:
            pts: array of shape (4, 2) holding (x, y) rows.

        Returns:
            A new float32 array of shape (4, 2).
        """
        rect = np.zeros((4, 2), dtype="float32")
        s = pts.sum(axis=1)
        rect[0] = pts[np.argmin(s)]
        rect[2] = pts[np.argmax(s)]
        tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
        diff = np.diff(tmp, axis=1)  # y - x for each remaining point
        rect[1] = tmp[np.argmin(diff)]
        rect[3] = tmp[np.argmax(diff)]
        return rect

    def clip_det_res(self, points, img_height, img_width):
        """Clamp (x, y) points into the image and truncate them to integers.

        Column 0 (x) is clipped to [0, img_width - 1] and column 1 (y) to
        [0, img_height - 1]; the clipped values are then truncated to whole
        numbers. When ``points`` is an ndarray it is modified in place and
        returned.
        """
        points = np.asarray(points)
        points[:, 0] = np.trunc(np.clip(points[:, 0], 0, img_width - 1))
        points[:, 1] = np.trunc(np.clip(points[:, 1], 0, img_height - 1))
        return points

    @staticmethod
    def _stack_boxes(boxes):
        """Pack a list of boxes into a single ndarray.

        Boxes that all share one shape are stacked into a regular array.
        Polygons with differing point counts cannot be stacked (NumPy >= 1.24
        raises ValueError for ragged input), so they are stored one per slot
        in a 1-D object array instead.
        """
        if len({np.shape(box) for box in boxes}) <= 1:
            return np.array(boxes)
        packed = np.empty(len(boxes), dtype=object)
        for idx, box in enumerate(boxes):
            packed[idx] = box
        return packed

    def filter_tag_det_res(self, dt_boxes, image_shape):
        """Order, clip and size-filter quadrilateral boxes.

        Each box is reordered with ``order_points_clockwise``, clipped to the
        image, and dropped when its width or height is 3 pixels or less.

        Args:
            dt_boxes: iterable of 4-point boxes, each an array of (x, y) rows.
            image_shape: shape of the original image; its first two entries
                are read as (height, width).

        Returns:
            ndarray of the kept boxes.
        """
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            box = self.order_points_clockwise(box)
            box = self.clip_det_res(box, img_height, img_width)
            rect_width = int(np.linalg.norm(box[0] - box[1]))
            rect_height = int(np.linalg.norm(box[0] - box[3]))
            if rect_width <= 3 or rect_height <= 3:
                continue
            dt_boxes_new.append(box)
        return self._stack_boxes(dt_boxes_new)

    def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
        """Clip boxes or polygons to the image without reordering/filtering.

        Used for polygon outputs, where boxes may have different numbers of
        points. In that case the result is a 1-D object array of per-polygon
        arrays; otherwise it is a regular stacked array.
        """
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = [
            self.clip_det_res(box, img_height, img_width) for box in dt_boxes
        ]
        return self._stack_boxes(dt_boxes_new)

    def _run_inference(self, img):
        """Feed the batched image to the predictor and return raw outputs."""
        if self.use_onnx:
            input_dict = {self.input_tensor.name: img}
            return self.predictor.run(self.output_tensors, input_dict)
        self.input_tensor.copy_from_cpu(img)
        self.predictor.run()
        return [
            output_tensor.copy_to_cpu()
            for output_tensor in self.output_tensors
        ]

    def _outputs_to_preds(self, outputs):
        """Map the raw model outputs to the keys the post-processor reads."""
        preds = {}
        if self.det_algorithm == "EAST":
            preds['f_geo'] = outputs[0]
            preds['f_score'] = outputs[1]
        elif self.det_algorithm == 'SAST':
            preds['f_border'] = outputs[0]
            preds['f_score'] = outputs[1]
            preds['f_tco'] = outputs[2]
            preds['f_tvo'] = outputs[3]
        elif self.det_algorithm in ['DB', 'PSE']:
            preds['maps'] = outputs[0]
        elif self.det_algorithm == 'FCE':
            for i, output in enumerate(outputs):
                preds['level_{}'.format(i)] = output
        else:
            raise NotImplementedError
        return preds

    def _returns_polygons(self):
        """Whether the configured post-processor emits polygons, not quads."""
        if self.det_algorithm == "SAST":
            return self.det_sast_polygon
        if self.det_algorithm in ["PSE", "FCE"]:
            return self.postprocess_op.box_type == 'poly'
        return False

    def __call__(self, img):
        """Detect text regions in ``img``.

        Args:
            img: input image array (presumably HxWxC as read by
                ``cv2.imread`` — TODO confirm for other callers).

        Returns:
            ``(dt_boxes, elapsed_seconds)``. ``dt_boxes`` holds the detected
            boxes/polygons clipped to the input image. It is ``None``, with an
            elapsed time of 0, when pre-processing yields no image.
        """
        # Only the original shape is needed later; no full copy required.
        ori_shape = img.shape
        data = {'image': img}

        st = time.time()
        if self.args.benchmark:
            self.autolog.times.start()

        data = transform(data, self.preprocess_op)
        # Guard before unpacking: a None result cannot be unpacked.
        if data is None:
            return None, 0
        img, shape_list = data
        if img is None:
            return None, 0
        img = np.expand_dims(img, axis=0)
        shape_list = np.expand_dims(shape_list, axis=0)
        img = img.copy()
        if self.args.benchmark:
            self.autolog.times.stamp()

        outputs = self._run_inference(img)
        # Stamp for both backends so preprocess/inference/postprocess times
        # stay aligned with the autolog time_keys.
        if self.args.benchmark:
            self.autolog.times.stamp()

        preds = self._outputs_to_preds(outputs)
        post_result = self.postprocess_op(preds, shape_list)
        dt_boxes = post_result[0]['points']
        if self._returns_polygons():
            dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_shape)
        else:
            dt_boxes = self.filter_tag_det_res(dt_boxes, ori_shape)

        if self.args.benchmark:
            self.autolog.times.end(stamp=True)
        et = time.time()
        return dt_boxes, et - st


if __name__ == "__main__":
    # Command-line entry point: run detection on every image found under
    # args.image_dir, log each result, save a visualisation per image and
    # write all results to <draw_img_save>/det_results.txt.
    args = utility.parse_args()
    image_file_list = get_image_file_list(args.image_dir)
    text_detector = TextDetector(args)
    count = 0
    total_time = 0
    draw_img_save = "./inference_results"

    if args.warmup:
        # Two throw-away passes on random noise before timing real images.
        img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
        for _ in range(2):
            text_detector(img)

    os.makedirs(draw_img_save, exist_ok=True)
    save_results = []
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        st = time.time()
        dt_boxes, _ = text_detector(img)
        elapse = time.time() - st
        # The first image is left out of the running total (presumably
        # because it carries one-off start-up cost).
        if count > 0:
            total_time += elapse
        count += 1
        if dt_boxes is None:
            # The detector returns None when pre-processing produced no image.
            logger.info("error in detecting image:{}".format(image_file))
            continue
        save_pred = os.path.basename(image_file) + "\t" + json.dumps(
            [x.tolist() for x in dt_boxes]) + "\n"
        save_results.append(save_pred)
        logger.info(save_pred)
        logger.info("The predict time of {}: {}".format(image_file, elapse))
        src_im = utility.draw_text_det_res(dt_boxes, image_file)
        img_name_pure = os.path.split(image_file)[-1]
        img_path = os.path.join(draw_img_save,
                                "det_res_{}".format(img_name_pure))
        cv2.imwrite(img_path, src_im)
        logger.info("The visualized image saved in {}".format(img_path))

    if count > 1:
        logger.info("Average predict time (excluding the first image): {:.4f}s".format(
            total_time / (count - 1)))

    # Explicit UTF-8 so non-ASCII image file names are written correctly
    # regardless of the platform's default encoding.
    with open(os.path.join(draw_img_save, "det_results.txt"), 'w',
              encoding='utf-8') as f:
        f.writelines(save_results)

    if args.benchmark:
        text_detector.autolog.report()