export_model.py 3.17 KB
Newer Older
dyning's avatar
dyning committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
16
17
18
19
20
21
import os
import sys

# Directory containing this script. Deliberately NOT named ``__dir__``: since
# Python 3.7 (PEP 562) a module-level ``__dir__`` is the hook ``dir(module)``
# calls, so binding that name to a string made ``dir()`` on this module raise
# ``TypeError: 'str' object is not callable``.
_script_dir = os.path.dirname(os.path.abspath(__file__))
# Put this directory and its parent on the import path. Presumably this is so
# the ``ppocr`` / ``tools`` packages resolve when run as a script -- confirm.
sys.path.append(_script_dir)
sys.path.append(os.path.abspath(os.path.join(_script_dir, '..')))

dyning's avatar
dyning committed
22
23
24
25
26
27
28
29
import argparse

import paddle
from paddle.jit import to_static

from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
30
from ppocr.utils.logging import get_logger
WenmuZhou's avatar
WenmuZhou committed
31
from tools.program import load_config, merge_config, ArgsParser
dyning's avatar
dyning committed
32
33


tink2123's avatar
tink2123 committed
34
35
36
37
38
39
40
41
def parse_args():
    """Parse this script's minimal command line with a plain argparse parser.

    Recognised options:
      -c / --config       configuration file to use (default: None)
      -o / --output_path  output directory (default: './output/infer/')

    Note: ``main()`` in this file parses its arguments with ``ArgsParser``
    instead, so this helper is not called from the code in this file.

    Returns:
        argparse.Namespace with ``config`` and ``output_path`` attributes.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("-c", "--config", help="configuration file to use")
    cli.add_argument(
        "-o", "--output_path", default='./output/infer/', type=str)
    namespace = cli.parse_args()
    return namespace


dyning's avatar
dyning committed
42
def main():
    """Export a trained model described by a config file as a static inference model.

    Steps:
      1. Parse command-line flags with ``ArgsParser`` and load the config from
         ``FLAGS.config``, then apply ``FLAGS.opt`` via ``merge_config``.
      2. Build the post-process object. If it exposes a ``character`` set
         (recognition models), size the head's ``out_channels`` to match.
      3. Build the network, call ``init_model`` on it and switch it to eval
         mode. ``init_model`` presumably restores the trained checkpoint named
         in the config -- confirm in ppocr.utils.save_load.
      4. Convert it with ``to_static`` using fixed input specs and save it with
         ``paddle.jit.save`` under ``<Global.save_inference_dir>/inference``.
    """
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    # The return value is ignored here. Presumably merge_config updates the
    # same config object load_config returned -- TODO confirm in tools/program.
    merge_config(FLAGS.opt)
    logger = get_logger()

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            config['Global'])

    # build model
    # For rec algorithms the head must emit one logit per character the
    # post-processor can decode, so derive out_channels from its charset.
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        config['Architecture']["Head"]['out_channels'] = char_num
    model = build_model(config['Architecture'])
    init_model(config, model, logger)
    model.eval()

    save_path = '{}/inference'.format(config['Global']['save_inference_dir'])

    if config['Architecture']['algorithm'] == "SRN":
        # SRN takes the image plus four auxiliary int64 tensors whose sizes
        # depend on the head's hyper-parameters. They were previously
        # hard-coded to 25 / 8, which silently breaks export when the config
        # uses other values. Read them from the head config instead; the
        # defaults keep the old behaviour.
        # NOTE(review): assumes the SRN head config uses the keys
        # 'max_text_length' and 'num_heads' -- confirm against SRNHead.
        head_config = config['Architecture']['Head']
        max_text_length = head_config.get('max_text_length', 25)
        num_heads = head_config.get('num_heads', 8)
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 1, 64, 256], dtype='float32'),
            [
                paddle.static.InputSpec(
                    shape=[None, 256, 1], dtype="int64"),
                paddle.static.InputSpec(
                    shape=[None, max_text_length, 1], dtype="int64"),
                paddle.static.InputSpec(
                    shape=[None, num_heads, max_text_length, max_text_length],
                    dtype="int64"),
                paddle.static.InputSpec(
                    shape=[None, num_heads, max_text_length, max_text_length],
                    dtype="int64"),
            ],
        ]
        model = to_static(model, input_spec=other_shape)

    else:
        # Fixed input size: [3, 640, 640] for detection, [3, 32, 100] for
        # every other model_type; the batch dimension stays dynamic (None).
        infer_shape = [3, 32, 100] if config['Architecture'][
            'model_type'] != "det" else [3, 640, 640]
        model = to_static(
            model,
            input_spec=[
                paddle.static.InputSpec(
                    shape=[None] + infer_shape, dtype='float32')
            ])

    paddle.jit.save(model, save_path)
    logger.info('inference model is saved to {}'.format(save_path))
dyning's avatar
dyning committed
91
92
93
94


if __name__ == "__main__":
    # Run the export only when executed as a script, not when imported.
    main()