"ppocr/vscode:/vscode.git/clone" did not exist on "453c6f68bd1f0b8470ffd3a2072d9483750808e1"
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import sys

__dir__ = os.path.dirname(__file__)
sys.path.append(os.path.join(__dir__, ''))

import cv2
import numpy as np
from pathlib import Path
import tarfile
import requests
from tqdm import tqdm

from tools.infer import predict_system
from ppocr.utils.logging import get_logger

logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list

__all__ = ['PaddleOCR']

WenmuZhou's avatar
WenmuZhou committed
36
# Download locations of the inference models.
#   'det': detection model archive per detection language.
#   'rec': per recognition language, the model archive plus the character
#          dictionary path (relative to this file's directory).
#   'cls': the single text-angle classifier archive.
model_urls = {
    'det': {
        'ch':
        'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
        'en':
        'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar'
    },
    'rec': {
        'ch': {
            'url':
            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
            'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
        },
        'en': {
            'url':
            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
            'dict_path': './ppocr/utils/en_dict.txt'
        },
        'french': {
            'url':
            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
            'dict_path': './ppocr/utils/dict/french_dict.txt'
        },
        'german': {
            'url':
            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
            'dict_path': './ppocr/utils/dict/german_dict.txt'
        },
        'korean': {
            'url':
            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
            'dict_path': './ppocr/utils/dict/korean_dict.txt'
        },
        'japan': {
            'url':
            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
            'dict_path': './ppocr/utils/dict/japan_dict.txt'
        },
        'chinese_cht': {
            'url':
            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
            'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt'
        },
        'ta': {
            'url':
            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
            'dict_path': './ppocr/utils/dict/ta_dict.txt'
        },
        'te': {
            'url':
            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar',
            'dict_path': './ppocr/utils/dict/te_dict.txt'
        },
        'ka': {
            'url':
            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar',
            'dict_path': './ppocr/utils/dict/ka_dict.txt'
        },
        'latin': {
            'url':
            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
            'dict_path': './ppocr/utils/dict/latin_dict.txt'
        },
        'arabic': {
            'url':
            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
            'dict_path': './ppocr/utils/dict/arabic_dict.txt'
        },
        'cyrillic': {
            'url':
            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
            'dict_path': './ppocr/utils/dict/cyrillic_dict.txt'
        },
        'devanagari': {
            'url':
            'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
            'dict_path': './ppocr/utils/dict/devanagari_dict.txt'
        }
    },
    'cls':
    'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar'
}

# Detection algorithms accepted by PaddleOCR.__init__.
SUPPORT_DET_MODEL = ['DB']
# Used as the first path component of the per-version model cache directory.
VERSION = 2.1
# Recognition algorithms accepted by PaddleOCR.__init__.
SUPPORT_REC_MODEL = ['CRNN']
# Root of the local model cache inside the user's home directory.
BASE_DIR = os.path.expanduser("~/.paddleocr/")
WenmuZhou's avatar
WenmuZhou committed
123
124
125
126
127
128
129
130
131
132
133
134


def download_with_progressbar(url, save_path):
    """Stream ``url`` into ``save_path`` while showing a tqdm progress bar.

    Args:
        url: address requested with ``requests.get(..., stream=True)``.
        save_path: path of the local file the response body is written to.

    The process is terminated through ``sys.exit`` (after logging an error)
    when the response carries no ``content-length`` header or when the number
    of bytes received differs from that header.
    """
    block_size = 1024  # 1 Kibibyte
    response = requests.get(url, stream=True)
    try:
        total_size_in_bytes = int(response.headers.get('content-length', 0))
        progress_bar = tqdm(
            total=total_size_in_bytes, unit='iB', unit_scale=True)
        try:
            with open(save_path, 'wb') as file:
                for data in response.iter_content(block_size):
                    progress_bar.update(len(data))
                    file.write(data)
            received = progress_bar.n
        finally:
            # Close the bar even when the transfer raises, so the terminal
            # line is not left in a half-drawn state.
            progress_bar.close()
    finally:
        # A streamed response keeps its connection until it is closed.
        response.close()

    if total_size_in_bytes == 0 or received != total_size_in_bytes:
        logger.error("Something went wrong while downloading models")
        # Non-zero status: this is a failure, not a normal termination.
        sys.exit(1)


140
def maybe_download(model_storage_directory, url):
    """Fetch and unpack an inference model unless it is already on disk.

    Nothing happens when both ``inference.pdiparams`` and
    ``inference.pdmodel`` already exist in ``model_storage_directory``.
    Otherwise the tar archive at ``url`` is downloaded into that directory,
    the regular-file members whose base name is one of the three expected
    model files are written flat into it (any directory prefix inside the
    archive is dropped), and the archive is deleted again.

    Args:
        model_storage_directory: directory that holds (or will hold) the model.
        url: location of the model tar archive.
    """
    wanted_files = ('inference.pdiparams', 'inference.pdiparams.info',
                    'inference.pdmodel')
    params_path = os.path.join(model_storage_directory, 'inference.pdiparams')
    model_path = os.path.join(model_storage_directory, 'inference.pdmodel')
    if os.path.exists(params_path) and os.path.exists(model_path):
        return

    tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
    print('download {} to {}'.format(url, tmp_path))
    os.makedirs(model_storage_directory, exist_ok=True)
    try:
        download_with_progressbar(url, tmp_path)
        with tarfile.open(tmp_path, 'r') as tarObj:
            for member in tarObj.getmembers():
                # Compare the exact base name: a substring test would also
                # accept unrelated members that merely contain a model file
                # name, and extractfile() yields None for non-regular members.
                filename = os.path.basename(member.name)
                if not member.isfile() or filename not in wanted_files:
                    continue
                source = tarObj.extractfile(member)
                target = os.path.join(model_storage_directory, filename)
                with source, open(target, 'wb') as f:
                    # Stream in chunks instead of loading the weights in RAM.
                    shutil.copyfileobj(source, f)
    finally:
        # Remove the archive on failure too (download_with_progressbar may
        # terminate via sys.exit), so no partial tarball stays behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
WenmuZhou's avatar
WenmuZhou committed
167
168


WenmuZhou's avatar
WenmuZhou committed
169
def parse_args(mMain=True, add_help=True):
    """Build the PaddleOCR option namespace.

    Args:
        mMain: when true, the options are parsed from the command line
            (``sys.argv``). When false, the command line is ignored and a
            namespace holding only the default values is returned, with
            ``image_dir`` set to ``''``.
        add_help: forwarded to ``argparse.ArgumentParser``.

    Returns:
        argparse.Namespace with one attribute per option declared below.
    """
    import argparse

    def str2bool(v):
        return v.lower() in ("true", "t", "1")

    def str2list(v):
        # "0,180" -> ['0', '180']; plain ``type=list`` would split the
        # argument into single characters.
        return [item.strip() for item in v.split(',') if item.strip()]

    parser = argparse.ArgumentParser(add_help=add_help)
    # params for prediction engine
    parser.add_argument("--use_gpu", type=str2bool, default=True)
    parser.add_argument("--ir_optim", type=str2bool, default=True)
    parser.add_argument("--use_tensorrt", type=str2bool, default=False)
    parser.add_argument("--gpu_mem", type=int, default=8000)

    # params for text detector
    parser.add_argument("--image_dir", type=str)
    parser.add_argument("--det_algorithm", type=str, default='DB')
    parser.add_argument("--det_model_dir", type=str, default=None)
    parser.add_argument("--det_limit_side_len", type=float, default=960)
    parser.add_argument("--det_limit_type", type=str, default='max')

    # DB params
    parser.add_argument("--det_db_thresh", type=float, default=0.3)
    parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
    parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
    # Boolean flags go through str2bool: with ``type=bool`` every non-empty
    # string, including "False", would be parsed as True.
    parser.add_argument("--use_dilation", type=str2bool, default=False)

    # EAST params
    parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
    parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
    parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)

    # params for text recognizer
    parser.add_argument("--rec_algorithm", type=str, default='CRNN')
    parser.add_argument("--rec_model_dir", type=str, default=None)
    parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
    parser.add_argument("--rec_char_type", type=str, default='ch')
    parser.add_argument("--rec_batch_num", type=int, default=6)
    parser.add_argument("--max_text_length", type=int, default=25)
    parser.add_argument("--rec_char_dict_path", type=str, default=None)
    parser.add_argument("--use_space_char", type=str2bool, default=True)
    parser.add_argument("--drop_score", type=float, default=0.5)

    # params for text classifier
    parser.add_argument("--cls_model_dir", type=str, default=None)
    parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
    parser.add_argument("--label_list", type=str2list, default=['0', '180'])
    parser.add_argument("--cls_batch_num", type=int, default=6)
    parser.add_argument("--cls_thresh", type=float, default=0.9)

    parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
    parser.add_argument("--use_zero_copy_run", type=str2bool, default=False)
    parser.add_argument("--use_pdserving", type=str2bool, default=False)

    parser.add_argument("--lang", type=str, default='ch')
    parser.add_argument("--det", type=str2bool, default=True)
    parser.add_argument("--rec", type=str2bool, default=True)
    parser.add_argument("--use_angle_cls", type=str2bool, default=False)

    if mMain:
        return parser.parse_args()

    # Library use: take the defaults from the same parser so the two code
    # paths cannot drift apart, and never look at the host's sys.argv.
    defaults = parser.parse_args([])
    defaults.image_dir = ''
    return defaults
WenmuZhou's avatar
WenmuZhou committed
267
268
269


class PaddleOCR(predict_system.TextSystem):
    """Detection / classification / recognition pipeline with model download.

    Resolves the language, model directories and character dictionary,
    downloads missing models and then initialises the parent text system.
    """

    def __init__(self, **kwargs):
        """
        paddleocr package
        args:
            **kwargs: other params show in paddleocr --help
        """
        postprocess_params = parse_args(mMain=False, add_help=False)
        postprocess_params.__dict__.update(**kwargs)
        self.use_angle_cls = postprocess_params.use_angle_cls
        lang = postprocess_params.lang

        # Individual language codes are served by one shared script model.
        latin_lang = [
            'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga',
            'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms',
            'mt', 'nl', 'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk',
            'sl', 'sq', 'sv', 'sw', 'tl', 'tr', 'uz', 'vi'
        ]
        arabic_lang = ['ar', 'fa', 'ug', 'ur']
        cyrillic_lang = [
            'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd',
            'ava', 'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
        ]
        devanagari_lang = [
            'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new',
            'gom', 'sa', 'bgc'
        ]
        if lang in latin_lang:
            lang = "latin"
        elif lang in arabic_lang:
            lang = "arabic"
        elif lang in cyrillic_lang:
            lang = "cyrillic"
        elif lang in devanagari_lang:
            lang = "devanagari"
        assert lang in model_urls[
            'rec'], 'param lang must in {}, but got {}'.format(
                model_urls['rec'].keys(), lang)

        # Only Chinese has its own detection model; the rest use the 'en' one.
        det_lang = "ch" if lang == "ch" else "en"

        use_inner_dict = False
        if postprocess_params.rec_char_dict_path is None:
            use_inner_dict = True
            postprocess_params.rec_char_dict_path = model_urls['rec'][lang][
                'dict_path']

        # init model dir
        if postprocess_params.det_model_dir is None:
            postprocess_params.det_model_dir = os.path.join(
                BASE_DIR, '{}/det/{}'.format(VERSION, det_lang))
        if postprocess_params.rec_model_dir is None:
            postprocess_params.rec_model_dir = os.path.join(
                BASE_DIR, '{}/rec/{}'.format(VERSION, lang))
        if postprocess_params.cls_model_dir is None:
            postprocess_params.cls_model_dir = os.path.join(
                BASE_DIR, '{}/cls'.format(VERSION))
        print(postprocess_params)

        # Reject unsupported algorithms before spending time on downloads.
        if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL:
            logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
            sys.exit(1)
        if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL:
            logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL))
            sys.exit(1)

        # download model
        maybe_download(postprocess_params.det_model_dir,
                       model_urls['det'][det_lang])
        maybe_download(postprocess_params.rec_model_dir,
                       model_urls['rec'][lang]['url'])
        maybe_download(postprocess_params.cls_model_dir, model_urls['cls'])

        if use_inner_dict:
            # Bundled dictionaries are addressed relative to this file.
            postprocess_params.rec_char_dict_path = str(
                Path(__file__).parent / postprocess_params.rec_char_dict_path)

        # init det_model and rec_model
        super().__init__(postprocess_params)

    def ocr(self, img, det=True, rec=True, cls=False):
        """
        ocr with paddleocr
        args:
            img: img for ocr, support ndarray, img_path and list or ndarray
            det: use text detection or not, if false, only rec will be exec. default is True
            rec: use text recognition or not, if false, only det will be exec. default is True
            cls: use the angle classifier for this call or not. default is False
        returns:
            det and rec: list of [box, recognition result] pairs;
            det only: list of boxes;
            otherwise: the classifier or recognizer output.
            None when the image cannot be loaded or detection yields None.
        """
        assert isinstance(img, (np.ndarray, list, str))
        if isinstance(img, list) and det:
            logger.error('When input a list of images, det must be false')
            sys.exit(1)

        # Remember the configured state: ``cls`` must only affect this call,
        # not switch the classifier off for every later call.
        classifier_ready = self.use_angle_cls
        if cls and not classifier_ready:
            logger.warning(
                'Since the angle classifier is not initialized, the angle classifier will not be uesd during the forward process'
            )

        if isinstance(img, str):
            # download net image
            if img.startswith('http'):
                download_with_progressbar(img, 'tmp.jpg')
                img = 'tmp.jpg'
            image_file = img
            img, flag = check_and_read_gif(image_file)
            if not flag:
                with open(image_file, 'rb') as f:
                    np_arr = np.frombuffer(f.read(), dtype=np.uint8)
                    img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
            if img is None:
                logger.error("error in loading image:{}".format(image_file))
                return None
        if isinstance(img, np.ndarray) and len(img.shape) == 2:
            # Single-channel input is expanded to three channels.
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        # NOTE(review): the flag is toggled on the instance because the
        # parent's __call__ presumably consults self.use_angle_cls - confirm.
        self.use_angle_cls = bool(cls) and bool(classifier_ready)
        try:
            if det and rec:
                dt_boxes, rec_res = self(img)
                if dt_boxes is None or rec_res is None:
                    # Same contract as the detection-only branch below.
                    return None
                return [[box.tolist(), res]
                        for box, res in zip(dt_boxes, rec_res)]
            elif det and not rec:
                dt_boxes, elapse = self.text_detector(img)
                if dt_boxes is None:
                    return None
                return [box.tolist() for box in dt_boxes]
            else:
                if not isinstance(img, list):
                    img = [img]
                if self.use_angle_cls:
                    img, cls_res, elapse = self.text_classifier(img)
                    if not rec:
                        return cls_res
                rec_res, elapse = self.text_recognizer(img)
                return rec_res
        finally:
            self.use_angle_cls = classifier_ready
399
400
401


def main():
    """Command-line entry point: run OCR on ``--image_dir`` and log results.

    ``--image_dir`` may be an http(s) URL, which is downloaded to ``tmp.jpg``
    first; any other value is handed to ``get_image_file_list``.
    """
    # for cmd
    args = parse_args(mMain=True)
    image_dir = args.image_dir
    if image_dir is None:
        # The option has no default on the command line; without this guard
        # the startswith() call below raises AttributeError.
        logger.error('--image_dir is required')
        return
    if image_dir.startswith('http'):
        download_with_progressbar(image_dir, 'tmp.jpg')
        image_file_list = ['tmp.jpg']
    else:
        image_file_list = get_image_file_list(args.image_dir)
    if len(image_file_list) == 0:
        logger.error('no images find in {}'.format(args.image_dir))
        return

    ocr_engine = PaddleOCR(**(args.__dict__))
    for img_path in image_file_list:
        logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
        result = ocr_engine.ocr(img_path,
                                det=args.det,
                                rec=args.rec,
                                cls=args.use_angle_cls)
        if result is not None:
            for line in result:
                logger.info(line)