paddleocr.py 7.98 KB
Newer Older
WenmuZhou's avatar
WenmuZhou committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

__dir__ = os.path.dirname(__file__)
sys.path.append(os.path.join(__dir__, ''))

import cv2
import numpy as np
from pathlib import Path
import tarfile
import requests
from tqdm import tqdm

from tools.infer import predict_system
from ppocr.utils.utility import initial_logger

logger = initial_logger()
32
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
WenmuZhou's avatar
WenmuZhou committed
33
34
35
36

# Names exported by ``from <this module> import *``.
__all__ = ['PaddleOCR']

# Download URLs of the default inference-model archives, keyed by pipeline
# stage ('det' = text detection, 'rec' = text recognition).
model_params = {
    'det': 'https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar',
    'rec':
    'https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance_infer.tar',
}

# Values accepted for the det_algorithm / rec_algorithm options.
SUPPORT_DET_MODEL = ['DB']
SUPPORT_REC_MODEL = ['CRNN']
# Root directory under which the default 'det' and 'rec' models are stored.
BASE_DIR = os.path.expanduser("~/.paddleocr/")
WenmuZhou's avatar
WenmuZhou committed
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61


def download_with_progressbar(url, save_path):
    """Download ``url`` to ``save_path`` while showing a tqdm progress bar.

    Args:
        url: address fetched with a streaming HTTP GET request.
        save_path: local file the response body is written to (overwritten).

    Raises:
        SystemExit: with a non-zero status when the server answers with an
            error status, or when the number of received bytes differs from
            the announced ``content-length``.
    """
    response = requests.get(url, stream=True)
    try:
        # Refuse error responses instead of saving the error page as a model.
        if not response.ok:
            logger.error("ERROR, download of {} failed with HTTP status {}".
                         format(url, response.status_code))
            sys.exit(1)

        # 0 means the server did not announce a size; the final size check
        # is skipped in that case.
        total_size_in_bytes = int(response.headers.get('content-length', 0))
        block_size = 1024  # 1 Kibibyte
        progress_bar = tqdm(
            total=total_size_in_bytes, unit='iB', unit_scale=True)
        try:
            with open(save_path, 'wb') as file:
                for data in response.iter_content(block_size):
                    progress_bar.update(len(data))
                    file.write(data)
        finally:
            # Close the bar even when writing fails part-way through.
            progress_bar.close()
    finally:
        # A streamed response keeps its connection until explicitly closed.
        response.close()

    if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
        logger.error("ERROR, something went wrong")
        # Non-zero status: this is a failure, not a normal termination.
        sys.exit(1)


62
def maybe_download(model_storage_directory, url):
    """Download and unpack an inference model unless it is already present.

    Nothing is done when both ``model`` and ``params`` already exist inside
    ``model_storage_directory`` (this is how a custom model is picked up).
    Otherwise the tar archive at ``url`` is downloaded into that directory,
    every archive member whose name contains "model" (checked first) or
    "params" is written to ``<directory>/model`` / ``<directory>/params``,
    and the archive is removed again.

    Args:
        model_storage_directory: directory that holds (or will hold) the
            ``model`` and ``params`` files; created when missing.
        url: address of the tar archive containing the inference model.
    """
    model_file = os.path.join(model_storage_directory, 'model')
    params_file = os.path.join(model_storage_directory, 'params')
    # using custom model
    if os.path.exists(model_file) and os.path.exists(params_file):
        return

    tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
    print('download {} to {}'.format(url, tmp_path))
    os.makedirs(model_storage_directory, exist_ok=True)
    try:
        download_with_progressbar(url, tmp_path)
        with tarfile.open(tmp_path, 'r') as tarObj:
            for member in tarObj.getmembers():
                if "model" in member.name:
                    filename = 'model'
                elif "params" in member.name:
                    filename = 'params'
                else:
                    continue
                file = tarObj.extractfile(member)
                # extractfile() yields None for members without data (e.g.
                # a directory whose name matches); there is nothing to copy.
                if file is None:
                    continue
                with file, open(
                        os.path.join(model_storage_directory, filename),
                        'wb') as f:
                    f.write(file.read())
    finally:
        # Do not leave a (possibly truncated) archive behind, also when the
        # download or the extraction fails.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
WenmuZhou's avatar
WenmuZhou committed
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102


def parse_args():
    """Parse the command line of the running process into OCR options.

    Returns:
        argparse.Namespace holding one attribute per option declared below
        (engine, detector, recognizer and det/rec switch settings).
    """
    import argparse

    def str2bool(v):
        # Only these spellings (case-insensitive) count as True.
        return v.lower() in ("true", "t", "1")

    parser = argparse.ArgumentParser()
    # params for prediction engine
    parser.add_argument("--use_gpu", type=str2bool, default=True)
    parser.add_argument("--ir_optim", type=str2bool, default=True)
    parser.add_argument("--use_tensorrt", type=str2bool, default=False)
    parser.add_argument("--gpu_mem", type=int, default=8000)

    # params for text detector
    parser.add_argument("--image_dir", type=str)
    parser.add_argument("--det_algorithm", type=str, default='DB')
    parser.add_argument("--det_model_dir", type=str, default=None)
    parser.add_argument("--det_max_side_len", type=float, default=960)

    # DB params
    parser.add_argument("--det_db_thresh", type=float, default=0.3)
    parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
    parser.add_argument("--det_db_unclip_ratio", type=float, default=2.0)

    # EAST params
    parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
    parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
    parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)

    # params for text recognizer
    parser.add_argument("--rec_algorithm", type=str, default='CRNN')
    parser.add_argument("--rec_model_dir", type=str, default=None)
    parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
    parser.add_argument("--rec_char_type", type=str, default='ch')
    parser.add_argument("--rec_batch_num", type=int, default=30)
    parser.add_argument("--max_text_length", type=int, default=25)
    parser.add_argument(
        "--rec_char_dict_path",
        type=str,
        default="./ppocr/utils/ppocr_keys_v1.txt")
    # ``type=bool`` would map every non-empty string (even "False") to True,
    # so these switches use str2bool like the other boolean options.
    parser.add_argument("--use_space_char", type=str2bool, default=True)
    parser.add_argument("--enable_mkldnn", type=str2bool, default=False)

    # switches selecting which stages the command-line tool runs
    parser.add_argument("--det", type=str2bool, default=True)
    parser.add_argument("--rec", type=str2bool, default=True)
    return parser.parse_args()


class PaddleOCR(predict_system.TextSystem):
    """Text detection + recognition system that fetches its models itself."""

    def __init__(self, **kwargs):
        """
        paddleocr package
        args:
            **kwargs: other params show in paddleocr --help; every keyword
                overrides the option of the same name from parse_args()
        """
        # NOTE: parse_args() parses the command line of the running process,
        # so options given there apply as well; kwargs take precedence.
        postprocess_params = parse_args()
        postprocess_params.__dict__.update(**kwargs)

        # Reject unsupported algorithms before any model is downloaded.
        if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL:
            logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
            sys.exit(1)
        if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL:
            logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL))
            sys.exit(1)

        # init model dir
        if postprocess_params.det_model_dir is None:
            postprocess_params.det_model_dir = os.path.join(BASE_DIR, 'det')
        if postprocess_params.rec_model_dir is None:
            postprocess_params.rec_model_dir = os.path.join(BASE_DIR, 'rec')
        print(postprocess_params)
        # download model
        maybe_download(postprocess_params.det_model_dir, model_params['det'])
        maybe_download(postprocess_params.rec_model_dir, model_params['rec'])

        # The dictionary path is resolved relative to this file's directory.
        postprocess_params.rec_char_dict_path = Path(
            __file__).parent / postprocess_params.rec_char_dict_path

        # init det_model and rec_model
        super().__init__(postprocess_params)

    def ocr(self, img, det=True, rec=True):
        """
        ocr with paddleocr
        args:
            img: img for ocr, support ndarray, img_path and list of ndarray
            det: use text detection or not, if false, only rec will be exec. default is True
            rec: use text recognition or not, if false, only det will be exec. default is True
        returns:
            det and rec: list of [box (as nested list), recognition result]
            det only: list of boxes (as nested lists)
            otherwise: the result of the text recognizer
            None when the image file cannot be loaded or no boxes are available
        """
        assert isinstance(img, (np.ndarray, list, str))
        if isinstance(img, str):
            image_file = img
            # Try the gif reader first, fall back to OpenCV otherwise.
            img, flag = check_and_read_gif(image_file)
            if not flag:
                img = cv2.imread(image_file)
            if img is None:
                logger.error("error in loading image:{}".format(image_file))
                return None
        if det and rec:
            dt_boxes, rec_res = self(img)
            # Same guard as the detection-only branch below; presumably the
            # full pipeline also reports "no boxes" as None -- TODO confirm.
            if dt_boxes is None:
                return None
            return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
        elif det and not rec:
            dt_boxes, _ = self.text_detector(img)
            if dt_boxes is None:
                return None
            return [box.tolist() for box in dt_boxes]
        else:
            # The recognizer is always called with a list of images.
            if not isinstance(img, list):
                img = [img]
            rec_res, _ = self.text_recognizer(img)
            return rec_res
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212


def main():
    """Command-line entry point: OCR every image found under --image_dir.

    Prints each image path followed by one line per OCR result entry.
    """
    args = parse_args()
    image_file_list = get_image_file_list(args.image_dir)
    if len(image_file_list) == 0:
        logger.error('no images find in {}'.format(args.image_dir))
        return
    ocr_engine = PaddleOCR()
    for img_path in image_file_list:
        print(img_path)
        result = ocr_engine.ocr(img_path, det=args.det, rec=args.rec)
        # ocr() returns None when the image cannot be loaded or detection
        # yields no boxes; iterating None would abort the whole run.
        if result is None:
            continue
        for line in result:
            print(line)