# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from copy import deepcopy
import json

import os
import sys

# Directory containing this script. Deliberately NOT named `__dir__`: a
# module-level `__dir__` is the PEP 562 hook that `dir(module)` calls, so
# binding it to a string would make `dir()` on this module raise TypeError.
# abspath keeps the entries valid even if the working directory changes later.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
# Make this directory and its parent importable so the `program` and `ppocr`
# imports below resolve regardless of the current working directory
# (assumes they live here or one level up).
sys.path.append(_THIS_DIR)
sys.path.append(os.path.join(_THIS_DIR, '..'))


def set_paddle_flags(**kwargs):
    """Export each keyword argument as an environment variable.

    The value is converted with ``str``. A variable that is already present
    in ``os.environ`` is left untouched, so settings supplied by the user's
    environment take precedence over these defaults.
    """
    for flag_name, flag_value in kwargs.items():
        os.environ.setdefault(flag_name, str(flag_value))


# NOTE(paddle-dev): All of these flags should be
# set before `import paddle`. Otherwise, it would
# not take any effect.
set_paddle_flags(
    FLAGS_eager_delete_tensor_gb=0,  # enable GC to save memory
)

# These imports must stay below the set_paddle_flags(...) call above: the
# flags only take effect if they are set before paddle is imported.
from paddle import fluid
from ppocr.utils.utility import create_module, get_image_file_list
import program
from ppocr.utils.save_load import init_model
from ppocr.data.reader_main import reader_main
import cv2

from ppocr.utils.utility import initial_logger
logger = initial_logger()


def _det_results_dir(save_res_path):
    """Return the directory where visualised detection images are written.

    This is the ``det_results`` sub-directory next to ``save_res_path``.
    When ``save_res_path`` has no directory component, the result is relative
    to the current working directory. The old string concatenation
    (``dirname + "/det_results/"``) pointed at the filesystem root in that case.
    """
    return os.path.join(os.path.dirname(save_res_path), "det_results")


def draw_det_res(dt_boxes, config, img, img_name):
    """Draw detected boxes on ``img`` and save it under ``det_results``.

    Args:
        dt_boxes: detected boxes. Each one is cast with ``astype(np.int32)``
            and reshaped to ``(-1, 1, 2)`` for ``cv2.polylines``
            (presumably numpy arrays of (x, y) corners -- confirm).
        config: parsed config; only ``config['Global']['save_res_path']`` is
            read, to locate the output directory.
        img: image array to draw on. It is modified in place. ``None``
            (e.g. a failed ``cv2.imread``) is logged and skipped.
        img_name: path of the source image; its basename is used as the
            output file name.

    Nothing is drawn or written when ``dt_boxes`` is empty.
    """
    if len(dt_boxes) == 0:
        return
    if img is None:
        logger.warning(
            "Cannot draw detection results for {}: image is None".format(
                img_name))
        return

    src_im = img
    for box in dt_boxes:
        points = box.astype(np.int32).reshape((-1, 1, 2))
        cv2.polylines(src_im, [points], True, color=(255, 255, 0), thickness=2)

    save_det_path = _det_results_dir(config['Global']['save_res_path'])
    if not os.path.exists(save_det_path):
        os.makedirs(save_det_path)
    save_path = os.path.join(save_det_path, os.path.basename(img_name))
    # cv2.imwrite reports failure (e.g. unsupported extension) via its return
    # value instead of raising, so check it before claiming success.
    if cv2.imwrite(save_path, src_im):
        logger.info("The detected Image saved in {}".format(save_path))
    else:
        logger.warning("Failed to save detected image to {}".format(save_path))


def _boxes_to_json_line(img_name, dt_boxes):
    """Format one result line: ``<img_name>\\t<json list>\\n``.

    Each box becomes ``{"transcription": "", "points": box.tolist()}``.
    """
    boxes_json = [{"transcription": "", "points": box.tolist()}
                  for box in dt_boxes]
    return img_name + "\t" + json.dumps(boxes_json) + "\n"


def main():
    """Run text detection over the test reader and save the results.

    The config path and overrides come from the module-level ``FLAGS``,
    which is parsed in ``__main__``. Weights are loaded from
    ``config['Global']['checkpoints']``.

    For each image, one line is written to
    ``config['Global']['save_res_path']`` (image name, tab, JSON list of
    boxes). A visualisation is then saved via ``draw_det_res``.

    Raises:
        Exception: if ``Global.algorithm`` is not EAST or DB, or if no
            checkpoint is configured.
    """
    config = program.load_config(FLAGS.config)
    program.merge_config(FLAGS.opt)
    print(config)

    global_config = config['Global']
    algorithm = global_config['algorithm']
    # Validate up front so an unsupported algorithm fails before the network
    # is built, rather than after the first batch has been run.
    if algorithm not in ('EAST', 'DB'):
        raise Exception("only support algorithm: ['EAST', 'DB']")

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = global_config['use_gpu']
    program.check_gpu(use_gpu)

    place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    det_model = create_module(config['Architecture']['function'])(params=config)

    startup_prog = fluid.Program()
    eval_prog = fluid.Program()
    with fluid.program_guard(eval_prog, startup_prog):
        with fluid.unique_name.guard():
            _, eval_outputs = det_model(mode="test")
            fetch_name_list = list(eval_outputs.keys())
            eval_fetch_list = [eval_outputs[v].name for v in fetch_name_list]

    eval_prog = eval_prog.clone(for_test=True)
    exe.run(startup_prog)

    # load checkpoints
    checkpoints = global_config.get('checkpoints')
    if not checkpoints:
        raise Exception(
            "checkpoints is not set in config['Global'] (got {})".format(
                checkpoints))
    fluid.load(eval_prog, checkpoints, exe)
    logger.info("Finish initing model from {}".format(checkpoints))

    # Built once rather than per batch. This assumes the postprocess object
    # keeps no state between calls -- confirm for new post-processors.
    postprocess_params = deepcopy(config["PostProcess"])
    postprocess_params.update(global_config)
    postprocess = create_module(postprocess_params['function'])(
        params=postprocess_params)

    save_res_path = global_config['save_res_path']
    save_res_dir = os.path.dirname(save_res_path)
    # dirname is '' for a bare file name, and os.makedirs('') would raise.
    if save_res_dir and not os.path.exists(save_res_dir):
        os.makedirs(save_res_dir)

    test_reader = reader_main(config=config, mode='test')
    tackling_num = 0
    with open(save_res_path, "wb") as fout:
        for data in test_reader():
            tackling_num += len(data)
            logger.info("tackling_num:%d", tackling_num)
            # Each sample is indexed as (image, ratio, image_name).
            img_list = [sample[0] for sample in data]
            ratio_list = [sample[1] for sample in data]
            img_name_list = [sample[2] for sample in data]

            outs = exe.run(eval_prog,
                           feed={'image': np.concatenate(img_list, axis=0)},
                           fetch_list=eval_fetch_list)

            if algorithm == 'EAST':
                dic = {'f_score': outs[0], 'f_geo': outs[1]}
            else:  # 'DB' -- validated above
                dic = {'maps': outs[0]}
            dt_boxes_list = postprocess(dic, ratio_list)

            for img_name, dt_boxes in zip(img_name_list, dt_boxes_list):
                fout.write(_boxes_to_json_line(img_name, dt_boxes).encode())
                src_img = cv2.imread(img_name)
                draw_det_res(dt_boxes, config, src_img, img_name)

    logger.info("success!")


if __name__ == '__main__':
    # main() reads the module-level FLAGS (FLAGS.config / FLAGS.opt).
    FLAGS = program.ArgsParser().parse_args()
    main()