predict_rec.py 16.7 KB
Newer Older
LDOUBLEV's avatar
LDOUBLEV committed
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
LDOUBLEV's avatar
LDOUBLEV committed
14
15
import os
import sys

from PIL import Image

# Directory containing this script. Deliberately NOT named ``__dir__``: a
# module-level ``__dir__`` is the PEP 562 hook that ``dir(module)`` calls, so
# binding it to a string makes ``dir()`` on this module raise TypeError.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_THIS_DIR)
# Also put the directory two levels up on sys.path; the ``tools`` / ``ppocr``
# imports below presumably resolve from there.
sys.path.append(os.path.abspath(os.path.join(_THIS_DIR, '../..')))

# Must be set before ``paddle`` is imported below.
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

import cv2
import numpy as np
import math
import time
import traceback
import paddle

import tools.infer.utility as utility
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif

logger = get_logger()

class TextRecognizer(object):
    """Batched text-line recognizer.

    Wraps the predictor returned by ``utility.create_predictor(args, 'rec',
    logger)`` (an ONNX session when ``args.use_onnx`` is set) and dispatches
    pre-/post-processing on ``args.rec_algorithm``: "SRN", "RARE", "NRTR"
    and "SAR" get dedicated label decoders, anything else uses CTC decoding.
    """

    # Label-decoder op name per recognition algorithm; algorithms not listed
    # here fall back to 'CTCLabelDecode'.
    _POSTPROCESS_NAMES = {
        "SRN": "SRNLabelDecode",
        "RARE": "AttnLabelDecode",
        "NRTR": "NRTRLabelDecode",
        "SAR": "SARLabelDecode",
    }

    def __init__(self, args):
        """Build the label decoder and the predictor from parsed CLI ``args``.

        ``args.rec_image_shape`` is a comma-separated list of ints: three
        values (C, H, W) for most algorithms, four (C, H, min W, max W) for
        SAR.
        """
        self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
        self.rec_batch_num = args.rec_batch_num
        self.rec_algorithm = args.rec_algorithm
        postprocess_params = {
            'name': self._POSTPROCESS_NAMES.get(self.rec_algorithm,
                                                'CTCLabelDecode'),
            "character_dict_path": args.rec_char_dict_path,
            "use_space_char": args.use_space_char
        }
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors, self.config = \
            utility.create_predictor(args, 'rec', logger)
        self.benchmark = args.benchmark
        self.use_onnx = args.use_onnx
        if args.benchmark:
            # auto_log is only needed (and only imported) when benchmarking.
            import auto_log
            pid = os.getpid()
            gpu_id = utility.get_infer_gpuid()
            self.autolog = auto_log.AutoLogger(
                model_name="rec",
                model_precision=args.precision,
                batch_size=args.rec_batch_num,
                data_shape="dynamic",
                save_path=None,
                inference_config=self.config,
                pids=pid,
                process_name=None,
                gpu_ids=gpu_id if args.use_gpu else None,
                time_keys=[
                    'preprocess_time', 'inference_time', 'postprocess_time'
                ],
                warmup=2,
                logger=logger)

    def resize_norm_img(self, img, max_wh_ratio):
        """Resize and normalise one image for the default / NRTR paths.

        NRTR: the image is converted BGR->gray, resized to 100x32 and scaled
        to ``x / 128 - 1``; returns a (1, 32, 100) float32 array.

        Otherwise: the image is resized to height ``imgH`` keeping its aspect
        ratio (capped at the padded width), normalised to [-1, 1] and
        right-padded with zeros to width ``imgH * max_wh_ratio`` (at least
        the configured ``imgW``; fixed to 100 for ONNX). Returns a
        (C, H, W) float32 array.
        """
        imgC, imgH, imgW = self.rec_image_shape
        if self.rec_algorithm == 'NRTR':
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            image_pil = Image.fromarray(np.uint8(img))
            # Image.LANCZOS is the same filter as Image.ANTIALIAS, which was
            # removed in Pillow 10.
            img = image_pil.resize([100, 32], Image.LANCZOS)
            img = np.array(img)
            norm_img = np.expand_dims(img, -1)
            norm_img = norm_img.transpose((2, 0, 1))
            return norm_img.astype(np.float32) / 128. - 1.

        assert imgC == img.shape[2]
        max_wh_ratio = max(max_wh_ratio, imgW / imgH)
        # Padded width follows the configured height (it used to be a
        # hard-coded 32, which squashed wide lines for other input heights).
        imgW = int(imgH * max_wh_ratio)
        if self.use_onnx:
            imgW = 100
        h, w = img.shape[:2]
        ratio = w / float(h)
        resized_w = min(imgW, int(math.ceil(imgH * ratio)))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def resize_norm_img_srn(self, img, image_shape):
        """Resize ``img`` for SRN and paste it, grayscale, on a black canvas.

        The target width is imgH, 2*imgH or 3*imgH depending on the aspect
        ratio (imgW beyond 3:1). The result is left-aligned in an
        (imgH, imgW) zero canvas and returned as a (1, imgH, imgW) float32
        array of raw 0-255 gray values (no normalisation).
        """
        imgC, imgH, imgW = image_shape

        img_black = np.zeros((imgH, imgW))
        im_hei = img.shape[0]
        im_wid = img.shape[1]

        if im_wid <= im_hei * 1:
            img_new = cv2.resize(img, (imgH * 1, imgH))
        elif im_wid <= im_hei * 2:
            img_new = cv2.resize(img, (imgH * 2, imgH))
        elif im_wid <= im_hei * 3:
            img_new = cv2.resize(img, (imgH * 3, imgH))
        else:
            img_new = cv2.resize(img, (imgW, imgH))

        img_np = np.asarray(img_new)
        img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
        img_black[:, 0:img_np.shape[1]] = img_np
        img_black = img_black[:, :, np.newaxis]

        row, col, c = img_black.shape
        c = 1

        return np.reshape(img_black, (c, row, col)).astype(np.float32)

    def srn_other_inputs(self, image_shape, num_heads, max_text_length):
        """Build SRN's auxiliary inputs, which depend only on the shapes.

        Returns ``[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
        gsrm_slf_attn_bias2]``:

        - ``encoder_word_pos``: positions 0..(imgH/8)*(imgW/8)-1, shape
          (1, feature_dim, 1).
        - ``gsrm_word_pos``: positions 0..max_text_length-1, shape
          (1, max_text_length, 1).
        - The two bias arrays have shape (1, num_heads, L, L). They are -1e9
          strictly above (bias1) or strictly below (bias2) the diagonal and 0
          elsewhere.
        """
        imgC, imgH, imgW = image_shape
        feature_dim = int((imgH / 8) * (imgW / 8))

        encoder_word_pos = np.array(range(0, feature_dim)).reshape(
            (feature_dim, 1)).astype('int64')
        gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
            (max_text_length, 1)).astype('int64')

        gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
        gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
            [-1, 1, max_text_length, max_text_length])
        gsrm_slf_attn_bias1 = np.tile(
            gsrm_slf_attn_bias1,
            [1, num_heads, 1, 1]).astype('float32') * [-1e9]

        gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
            [-1, 1, max_text_length, max_text_length])
        gsrm_slf_attn_bias2 = np.tile(
            gsrm_slf_attn_bias2,
            [1, num_heads, 1, 1]).astype('float32') * [-1e9]

        encoder_word_pos = encoder_word_pos[np.newaxis, :]
        gsrm_word_pos = gsrm_word_pos[np.newaxis, :]

        return [
            encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
            gsrm_slf_attn_bias2
        ]

    def process_image_srn(self, img, image_shape, num_heads, max_text_length):
        """Return SRN's five model inputs for one image.

        Returns a tuple ``(norm_img, encoder_word_pos, gsrm_word_pos,
        gsrm_slf_attn_bias1, gsrm_slf_attn_bias2)``. ``norm_img`` has shape
        (1, 1, imgH, imgW). The position arrays are int64 and the bias arrays
        are float32.
        """
        norm_img = self.resize_norm_img_srn(img, image_shape)
        norm_img = norm_img[np.newaxis, :]

        [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
            self.srn_other_inputs(image_shape, num_heads, max_text_length)

        # srn_other_inputs multiplies by a Python list, which upcasts the
        # biases to float64; cast back to the dtypes fed to the model.
        gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
        gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
        encoder_word_pos = encoder_word_pos.astype(np.int64)
        gsrm_word_pos = gsrm_word_pos.astype(np.int64)

        return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
                gsrm_slf_attn_bias2)

    def resize_norm_img_sar(self, img, image_shape,
                            width_downsample_ratio=0.25):
        """Resize, normalise and pad one image for SAR.

        ``image_shape`` is (C, H, min W, max W). The resized width is rounded
        to a multiple of ``int(1 / width_downsample_ratio)`` and clamped to
        [min W, max W]. Values are normalised to [-1, 1] and the image is
        right-padded with -1 to max W.

        Returns:
            (padded (C, H, max W) float32 array, shape before padding,
            padded shape, valid_ratio), where valid_ratio is the unclamped
            resized width / max W, capped at 1.0.
        """
        imgC, imgH, imgW_min, imgW_max = image_shape
        h = img.shape[0]
        w = img.shape[1]
        valid_ratio = 1.0
        # make sure new_width is an integral multiple of width_divisor.
        width_divisor = int(1 / width_downsample_ratio)
        # resize
        ratio = w / float(h)
        resize_w = math.ceil(imgH * ratio)
        if resize_w % width_divisor != 0:
            resize_w = round(resize_w / width_divisor) * width_divisor
        if imgW_min is not None:
            resize_w = max(imgW_min, resize_w)
        if imgW_max is not None:
            valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
            resize_w = min(imgW_max, resize_w)
        resized_image = cv2.resize(img, (resize_w, imgH))
        resized_image = resized_image.astype('float32')
        # normalise; the single-channel branch presumably expects a 2-D
        # (grayscale) image -- TODO confirm against callers.
        if image_shape[0] == 1:
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
        else:
            resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        resize_shape = resized_image.shape
        padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
        padding_im[:, :, 0:resize_w] = resized_image
        pad_shape = padding_im.shape

        return padding_im, resize_shape, pad_shape, valid_ratio

    def _run_predictor(self, inputs):
        """Run one forward pass and return the list of raw outputs.

        Args:
            inputs: list of numpy arrays. ONNX sessions receive only
                ``inputs[0]`` (keyed by ``self.input_tensor.name``). Paddle
                predictors receive a single input through
                ``self.input_tensor``, or several inputs positionally in
                ``get_input_names()`` order.
        """
        if self.use_onnx:
            input_dict = {self.input_tensor.name: inputs[0]}
            return self.predictor.run(self.output_tensors, input_dict)
        if len(inputs) == 1:
            self.input_tensor.copy_from_cpu(inputs[0])
        else:
            for i, name in enumerate(self.predictor.get_input_names()):
                self.predictor.get_input_handle(name).copy_from_cpu(inputs[i])
        self.predictor.run()
        outputs = [tensor.copy_to_cpu() for tensor in self.output_tensors]
        if self.benchmark:
            # End of the inference phase (stamped on the Paddle path only).
            self.autolog.times.stamp()
        return outputs

    def __call__(self, img_list):
        """Recognise every image in ``img_list``.

        Args:
            img_list: list of numpy images indexed (height, width, ...); the
                colour paths treat them as BGR.

        Returns:
            ``(rec_res, elapsed)``: ``rec_res[i]`` is the decoder's result
            for ``img_list[i]`` (input order is preserved), and ``elapsed``
            is the wall-clock time of the whole call in seconds.
        """
        img_num = len(img_list)
        # Visit images in order of increasing width/height ratio so each
        # batch holds similarly shaped lines (less padding per batch).
        width_list = [img.shape[1] / float(img.shape[0]) for img in img_list]
        indices = np.argsort(np.array(width_list))
        # One independent placeholder per slot; ``[[...]] * n`` would make
        # every slot alias the same list object.
        rec_res = [['', 0.0] for _ in range(img_num)]
        batch_num = self.rec_batch_num
        st = time.time()
        if self.benchmark:
            self.autolog.times.start()
        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            batch_indices = indices[beg_img_no:end_img_no]

            # The widest aspect ratio in the batch sets the padded width on
            # the default (resize_norm_img) path.
            max_wh_ratio = 0
            for idx in batch_indices:
                h, w = img_list[idx].shape[0:2]
                max_wh_ratio = max(max_wh_ratio, w * 1.0 / h)

            # Per-batch accumulators. They must be created once per batch, not
            # per image; otherwise only the last image's SAR valid ratio / SRN
            # auxiliary inputs survive and no longer match the batch size.
            norm_img_batch = []
            valid_ratios = []
            encoder_word_pos_list = []
            gsrm_word_pos_list = []
            gsrm_slf_attn_bias1_list = []
            gsrm_slf_attn_bias2_list = []
            for idx in batch_indices:
                img = img_list[idx]
                if self.rec_algorithm == "SAR":
                    norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
                        img, self.rec_image_shape)
                    norm_img_batch.append(norm_img[np.newaxis, :])
                    valid_ratios.append(np.expand_dims(valid_ratio, axis=0))
                elif self.rec_algorithm == "SRN":
                    (norm_img, encoder_word_pos, gsrm_word_pos,
                     gsrm_slf_attn_bias1, gsrm_slf_attn_bias2) = \
                        self.process_image_srn(img, self.rec_image_shape, 8, 25)
                    norm_img_batch.append(norm_img)
                    encoder_word_pos_list.append(encoder_word_pos)
                    gsrm_word_pos_list.append(gsrm_word_pos)
                    gsrm_slf_attn_bias1_list.append(gsrm_slf_attn_bias1)
                    gsrm_slf_attn_bias2_list.append(gsrm_slf_attn_bias2)
                else:
                    norm_img = self.resize_norm_img(img, max_wh_ratio)
                    norm_img_batch.append(norm_img[np.newaxis, :])
            # .copy() yields a C-contiguous array; concatenate may follow the
            # (possibly transposed) memory layout of its inputs.
            norm_img_batch = np.concatenate(norm_img_batch).copy()
            if self.benchmark:
                self.autolog.times.stamp()

            if self.rec_algorithm == "SRN":
                outputs = self._run_predictor([
                    norm_img_batch,
                    np.concatenate(encoder_word_pos_list),
                    np.concatenate(gsrm_word_pos_list),
                    np.concatenate(gsrm_slf_attn_bias1_list),
                    np.concatenate(gsrm_slf_attn_bias2_list),
                ])
                preds = {"predict": outputs[2]}
            elif self.rec_algorithm == "SAR":
                outputs = self._run_predictor(
                    [norm_img_batch, np.concatenate(valid_ratios)])
                preds = outputs[0]
            else:
                outputs = self._run_predictor([norm_img_batch])
                # Multi-output Paddle models hand the whole output list to
                # the decoder; ONNX sessions always use the first output.
                if self.use_onnx or len(outputs) == 1:
                    preds = outputs[0]
                else:
                    preds = outputs

            rec_result = self.postprocess_op(preds)
            for rno, result in enumerate(rec_result):
                rec_res[indices[beg_img_no + rno]] = result
            if self.benchmark:
                self.autolog.times.end(stamp=True)
        return rec_res, time.time() - st


def main(args):
    """Recognise text in every image found under ``args.image_dir``.

    Images that cannot be read are logged and skipped. If recognition
    raises, the traceback is logged and the process exits with status 1.
    Per-image predictions are logged; a benchmark report is printed when
    ``args.benchmark`` is set.
    """
    image_file_list = get_image_file_list(args.image_dir)
    text_recognizer = TextRecognizer(args)
    valid_image_file_list = []
    img_list = []

    # warmup 2 times
    if args.warmup:
        img = np.random.uniform(0, 255, [32, 320, 3]).astype(np.uint8)
        for _ in range(2):
            text_recognizer([img] * int(args.rec_batch_num))

    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)

    try:
        rec_res, _ = text_recognizer(img_list)
    except Exception as E:
        logger.error(traceback.format_exc())
        logger.error(E)
        # Non-zero status so callers/scripts can detect the failure; the
        # ``site`` builtin exit() exited with 0 and is absent under -S.
        sys.exit(1)
    for image_file, res in zip(valid_image_file_list, rec_res):
        logger.info("Predicts of {}:{}".format(image_file, res))
    if args.benchmark:
        text_recognizer.autolog.report()


if __name__ == "__main__":
    # Script entry point: parse the shared inference CLI options, then run.
    cli_args = utility.parse_args()
    main(cli_args)