predict_rec.py 10.7 KB
Newer Older
LDOUBLEV's avatar
LDOUBLEV committed
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
LDOUBLEV's avatar
LDOUBLEV committed
14
15
import os
import sys
WenmuZhou's avatar
WenmuZhou committed
16

17
# Put this script's directory and the directory two levels up ('../..') on
# sys.path; presumably so the `tools.` / `ppocr.` imports below resolve when
# this file is executed directly as a script.
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.extend([__dir__, os.path.abspath(os.path.join(__dir__, '../..'))])

# NOTE(review): this is set before `import paddle` further down; presumably
# paddle reads the allocator-strategy flag when it initializes — confirm.
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

LDOUBLEV's avatar
LDOUBLEV committed
23
24
25
26
import cv2
import numpy as np
import math
import time
WenmuZhou's avatar
WenmuZhou committed
27
import traceback
tink2123's avatar
tink2123 committed
28
import paddle
29
30

import tools.infer.utility as utility
WenmuZhou's avatar
WenmuZhou committed
31
32
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
33
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
LDOUBLEV's avatar
LDOUBLEV committed
34

WenmuZhou's avatar
WenmuZhou committed
35
36
# Module-level logger; passed to utility.create_predictor and used by main()
# to report loading errors, predictions and timing.
logger = get_logger()

LDOUBLEV's avatar
LDOUBLEV committed
37
38
39

class TextRecognizer(object):
    """Batched text-line recognizer backed by a Paddle inference predictor.

    The predictor is created with ``utility.create_predictor(args, 'rec', ...)``
    and its raw outputs are decoded by a post-processing op built with
    ``build_post_process`` (``SRNLabelDecode`` when ``rec_algorithm == "SRN"``,
    otherwise ``CTCLabelDecode``).
    """

    def __init__(self, args):
        """Build the predictor and the label decoder from parsed CLI ``args``.

        Reads ``rec_image_shape`` (comma-separated "C,H,W"), ``rec_char_type``,
        ``rec_batch_num``, ``rec_algorithm``, ``rec_char_dict_path`` and
        ``use_space_char`` from ``args``.
        """
        self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
        self.character_type = args.rec_char_type
        self.rec_batch_num = args.rec_batch_num
        self.rec_algorithm = args.rec_algorithm
        # SRN has its own decoder; every other algorithm is decoded as CTC.
        decoder_name = ('SRNLabelDecode'
                        if self.rec_algorithm == "SRN" else 'CTCLabelDecode')
        postprocess_params = {
            'name': decoder_name,
            "character_type": args.rec_char_type,
            "character_dict_path": args.rec_char_dict_path,
            "use_space_char": args.use_space_char
        }
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors = \
            utility.create_predictor(args, 'rec', logger)

    def resize_norm_img(self, img, max_wh_ratio):
        """Resize, normalize and right-pad one HWC image into a CHW tensor.

        Args:
            img: HWC image whose channel count must equal
                ``rec_image_shape[0]``.
            max_wh_ratio: largest width/height ratio in the current batch.
                When ``character_type == "ch"`` the canvas width becomes
                ``imgH * max_wh_ratio`` so every image in the batch shares
                one shape.

        Returns:
            float32 array of shape ``(imgC, imgH, imgW)``; pixel values are
            mapped by ``(x / 255 - 0.5) / 0.5`` (i.e. to [-1, 1] for
            0-255 input) and the padded area is zero.
        """
        imgC, imgH, imgW = self.rec_image_shape
        assert imgC == img.shape[2]
        if self.character_type == "ch":
            # Scale with the configured height (not a hard-coded 32) so models
            # with other input heights keep the batch's widest aspect ratio.
            imgW = int(imgH * max_wh_ratio)
        h, w = img.shape[:2]
        ratio = w / float(h)
        # Preserve the aspect ratio at height imgH, but never exceed the canvas.
        resized_w = min(imgW, int(math.ceil(imgH * ratio)))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        # HWC -> CHW, then scale to [0, 1] and shift/scale to [-1, 1].
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        # Left-align on a zero canvas of the (possibly batch-wide) width.
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def resize_norm_img_srn(self, img, image_shape):
        """Resize a BGR image for SRN and return it as a (1, H, W) gray tensor.

        The resized width is snapped to 1x, 2x or 3x the target height
        (depending on the source aspect ratio), or to the full canvas width for
        wider images, then converted to grayscale and left-aligned on a zero
        canvas of ``imgH x imgW``.

        Args:
            img: BGR image (converted with ``cv2.COLOR_BGR2GRAY``).
            image_shape: ``[C, H, W]``; only H and W are used.

        Returns:
            float32 array of shape ``(1, imgH, imgW)``; pixel values are not
            rescaled.
        """
        _, imgH, imgW = image_shape
        img_black = np.zeros((imgH, imgW))
        im_hei, im_wid = img.shape[0], img.shape[1]

        target_w = imgW
        for multiple in (1, 2, 3):
            if im_wid <= im_hei * multiple:
                target_w = imgH * multiple
                break
        img_new = cv2.resize(img, (target_w, imgH))

        img_np = cv2.cvtColor(np.asarray(img_new), cv2.COLOR_BGR2GRAY)
        img_black[:, 0:img_np.shape[1]] = img_np
        # (imgH, imgW) -> (1, imgH, imgW)
        return img_black[np.newaxis, :, :].astype(np.float32)

    def srn_other_inputs(self, image_shape, num_heads, max_text_length):
        """Build SRN's positional-index and attention-bias inputs for one image.

        Args:
            image_shape: ``[C, H, W]``; ``(H / 8) * (W / 8)`` positions are
                generated for the encoder (presumably the backbone's /8
                feature map — confirm against the model).
            num_heads: number of attention heads the biases are tiled over.
            max_text_length: number of decoder character positions.

        Returns:
            ``[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
            gsrm_slf_attn_bias2]``: int64 index arrays of shape
            ``(1, feature_dim, 1)`` and ``(1, max_text_length, 1)``, and two
            ``(1, num_heads, max_text_length, max_text_length)`` biases that
            hold ``-1e9`` strictly above (bias1) / strictly below (bias2) the
            diagonal and 0 elsewhere.
        """
        imgC, imgH, imgW = image_shape
        feature_dim = int((imgH / 8) * (imgW / 8))

        encoder_word_pos = np.arange(feature_dim).reshape(
            (feature_dim, 1)).astype('int64')
        gsrm_word_pos = np.arange(max_text_length).reshape(
            (max_text_length, 1)).astype('int64')

        gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
        gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
            [-1, 1, max_text_length, max_text_length])
        gsrm_slf_attn_bias1 = np.tile(
            gsrm_slf_attn_bias1,
            [1, num_heads, 1, 1]).astype('float32') * [-1e9]

        gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
            [-1, 1, max_text_length, max_text_length])
        gsrm_slf_attn_bias2 = np.tile(
            gsrm_slf_attn_bias2,
            [1, num_heads, 1, 1]).astype('float32') * [-1e9]

        encoder_word_pos = encoder_word_pos[np.newaxis, :]
        gsrm_word_pos = gsrm_word_pos[np.newaxis, :]

        return [
            encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
            gsrm_slf_attn_bias2
        ]

    def process_image_srn(self, img, image_shape, num_heads, max_text_length):
        """Prepare all five SRN predictor inputs for a single image.

        Returns:
            ``(norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
            gsrm_slf_attn_bias2)``; each array carries a leading batch axis of
            size 1 so a batch can be formed with ``np.concatenate``.
        """
        norm_img = self.resize_norm_img_srn(img, image_shape)
        norm_img = norm_img[np.newaxis, :]

        [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
            self.srn_other_inputs(image_shape, num_heads, max_text_length)

        gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
        gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
        encoder_word_pos = encoder_word_pos.astype(np.int64)
        gsrm_word_pos = gsrm_word_pos.astype(np.int64)

        return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
                gsrm_slf_attn_bias2)

    def __call__(self, img_list):
        """Recognize every image in ``img_list``.

        Images are processed in batches of ``rec_batch_num`` after sorting by
        width/height ratio; results are written back in the original order.

        Returns:
            ``(rec_res, elapse)``: ``rec_res[i]`` is the post-processing result
            for ``img_list[i]`` (``['', 0.0]`` if none was produced) and
            ``elapse`` is the summed wall time in seconds spent feeding inputs,
            running the predictor and post-processing.
        """
        img_num = len(img_list)
        # Sorting by aspect ratio groups similarly-shaped images into the same
        # batch, which keeps per-batch padding small.
        width_list = [img.shape[1] / float(img.shape[0]) for img in img_list]
        indices = np.argsort(np.array(width_list))

        # One independent placeholder per image; `[x] * n` would alias a
        # single shared list object.
        rec_res = [['', 0.0] for _ in range(img_num)]
        batch_num = self.rec_batch_num
        elapse = 0
        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            batch_indices = indices[beg_img_no:end_img_no]

            max_wh_ratio = 0
            for idx in batch_indices:
                h, w = img_list[idx].shape[0:2]
                max_wh_ratio = max(max_wh_ratio, w * 1.0 / h)

            norm_img_batch = []
            # SRN auxiliary inputs are accumulated over the whole batch so
            # their batch dimension matches norm_img_batch.
            encoder_word_pos_list = []
            gsrm_word_pos_list = []
            gsrm_slf_attn_bias1_list = []
            gsrm_slf_attn_bias2_list = []
            for idx in batch_indices:
                if self.rec_algorithm != "SRN":
                    norm_img = self.resize_norm_img(img_list[idx],
                                                    max_wh_ratio)
                    norm_img_batch.append(norm_img[np.newaxis, :])
                else:
                    (norm_img, encoder_word_pos, gsrm_word_pos,
                     gsrm_slf_attn_bias1, gsrm_slf_attn_bias2) = \
                        self.process_image_srn(
                            img_list[idx], self.rec_image_shape, 8, 25)
                    norm_img_batch.append(norm_img)
                    encoder_word_pos_list.append(encoder_word_pos)
                    gsrm_word_pos_list.append(gsrm_word_pos)
                    gsrm_slf_attn_bias1_list.append(gsrm_slf_attn_bias1)
                    gsrm_slf_attn_bias2_list.append(gsrm_slf_attn_bias2)

            # np.concatenate always returns a fresh contiguous array.
            norm_img_batch = np.concatenate(norm_img_batch)

            starttime = time.time()
            if self.rec_algorithm == "SRN":
                inputs = [
                    norm_img_batch,
                    np.concatenate(encoder_word_pos_list),
                    np.concatenate(gsrm_word_pos_list),
                    np.concatenate(gsrm_slf_attn_bias1_list),
                    np.concatenate(gsrm_slf_attn_bias2_list),
                ]
                input_names = self.predictor.get_input_names()
                for i, input_name in enumerate(input_names):
                    input_tensor = self.predictor.get_input_handle(input_name)
                    input_tensor.copy_from_cpu(inputs[i])
                self.predictor.run()
                outputs = [t.copy_to_cpu() for t in self.output_tensors]
                preds = {"predict": outputs[2]}
            else:
                self.input_tensor.copy_from_cpu(norm_img_batch)
                self.predictor.run()
                outputs = [t.copy_to_cpu() for t in self.output_tensors]
                preds = outputs[0]

            rec_result = self.postprocess_op(preds)
            for idx, result in zip(batch_indices, rec_result):
                rec_res[idx] = result
            elapse += time.time() - starttime
        return rec_res, elapse
LDOUBLEV's avatar
LDOUBLEV committed
239
240


241
def main(args):
    """Recognize every image found under ``args.image_dir`` and log results.

    Files that cannot be loaded are logged and skipped. If recognition raises,
    the traceback and a troubleshooting hint are logged and the process exits
    with a non-zero status so callers can detect the failure.
    """
    image_file_list = get_image_file_list(args.image_dir)
    text_recognizer = TextRecognizer(args)
    valid_image_file_list = []
    img_list = []
    for image_file in image_file_list:
        # Try check_and_read_gif first; when it reports flag=False, fall back
        # to cv2.imread.
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)
    try:
        rec_res, predict_time = text_recognizer(img_list)
    except Exception:
        # Narrowed from a bare `except:` so Ctrl-C / SystemExit still propagate.
        logger.info(traceback.format_exc())
        logger.info(
            "ERROR!!!! \n"
            "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
            "If your model has tps module:  "
            "TPS does not support variable shape.\n"
            "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
        # Non-zero status: the builtin exit() used before reported success.
        sys.exit(1)
    for image_file, res in zip(valid_image_file_list, rec_res):
        logger.info("Predicts of {}:{}".format(image_file, res))
    logger.info("Total predict time for {} images, cost: {:.3f}".format(
        len(img_list), predict_time))
271
272
273
274


if __name__ == "__main__":
    # Script entry point: parse the shared inference CLI flags, then run.
    parsed_args = utility.parse_args()
    main(parsed_args)