# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

# Append the directory holding this script, and its parent, to sys.path so
# the project imports below (``ppocr.*`` and ``tools.program``) resolve when
# the file is executed directly. The parent is presumably the repository
# root, since ``tools.program`` is imported as a package from it.
# NOTE(review): a module-level ``__dir__`` is the PEP 562 hook that
# ``dir(module)`` calls; binding it to a string makes ``dir()`` on this
# module raise TypeError. Harmless when run as a script; consider renaming.
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.extend([__dir__, os.path.abspath(os.path.join(__dir__, os.pardir))])

import yaml
import paddle
import paddle.distributed as dist

# Fix Paddle's global random seed at import time, before any project module
# is imported or any model is built.
paddle.seed(2)

from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model
from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model, load_dygraph_params
import tools.program as program

# NOTE(review): the return value is discarded, so this call appears to have
# no effect here; kept to preserve existing behaviour — confirm before
# removing.
dist.get_world_size()


def main(config, device, logger, vdl_writer):
    """Build every training component from ``config`` and run training.

    Args:
        config: Parsed training configuration. It is indexed for the
            ``Global``, ``Train``, ``PostProcess``, ``Architecture``,
            ``Loss``, ``Optimizer`` and ``Metric`` sections; the ``Eval``
            section is optional and, when missing or empty, disables
            validation.
        device: Device handle forwarded to ``build_dataloader`` and to
            ``program.train``.
        logger: Logger used for progress and error messages.
        vdl_writer: Writer forwarded unchanged to ``program.train``
            (presumably a VisualDL writer or ``None`` — confirm in
            tools/program.py).

    Returns:
        None. Returns early, after logging an error, when the training
        dataloader yields no batches.
    """
    global_config = config['Global']

    # init dist environment
    if global_config['distributed']:
        dist.init_parallel_env()

    # build dataloader
    train_dataloader = build_dataloader(config, 'Train', device, logger)
    if len(train_dataloader) == 0:
        logger.error(
            "No Images in train dataset, please ensure\n" +
            "\t1. The images num in the train label_file_list should be larger than or equal with batch size.\n"
            +
            "\t2. The annotation file and path in the configuration file are provided normally."
        )
        return

    # ``Eval`` is optional: a missing or empty section disables validation
    # instead of raising KeyError.
    if config.get('Eval'):
        valid_dataloader = build_dataloader(config, 'Eval', device, logger)
    else:
        valid_dataloader = None

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model
    # For recognition algorithms the post-processor carries the character
    # set; the head's ``out_channels`` is sized to match it.
    arch_config = config['Architecture']
    if hasattr(post_process_class, 'character'):
        char_num = len(post_process_class.character)
        if arch_config["algorithm"] == "Distillation":
            # Distillation wraps several sub-models; every head must agree.
            for sub_model_config in arch_config["Models"].values():
                sub_model_config["Head"]['out_channels'] = char_num
        else:  # base rec model
            arch_config["Head"]['out_channels'] = char_num

    model = build_model(arch_config)
    if global_config['distributed']:
        model = paddle.DataParallel(model)

    # build loss
    loss_class = build_loss(config['Loss'])

    # build optim; the scheduler is sized by epochs x iterations per epoch
    optimizer, lr_scheduler = build_optimizer(
        config['Optimizer'],
        epochs=global_config['epoch_num'],
        step_each_epoch=len(train_dataloader),
        parameters=model.parameters())

    # build metric
    eval_class = build_metric(config['Metric'])

    # load pretrain model; the returned dict is handed to program.train as
    # ``pre_best_model_dict`` (presumably best-metric info restored from a
    # checkpoint — confirm in ppocr/utils/save_load.py).
    pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)

    logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
    if valid_dataloader is not None:
        logger.info('valid dataloader has {} iters'.format(
            len(valid_dataloader)))

    # start train
    program.train(config, train_dataloader, valid_dataloader, device, model,
                  loss_class, optimizer, lr_scheduler, post_process_class,
                  eval_class, pre_best_model_dict, logger, vdl_writer)


def test_reader(config, device, logger, log_interval=1):
    """Iterate the training dataloader once and log data-loading timings.

    Debugging helper for the input pipeline only; no model is built. Every
    ``log_interval`` batches it logs the running batch count, ``len`` of the
    batch's first field, and the seconds elapsed since the previous log line.

    Args:
        config: Parsed configuration containing a ``Train`` section.
        device: Device handle forwarded to ``build_dataloader``.
        logger: Logger receiving the timing lines.
        log_interval: Log every this many batches (positive int). The
            default of 1 logs every batch, matching the previous behaviour.
    """
    import time
    import traceback

    loader = build_dataloader(config, 'Train', device, logger)
    starttime = time.time()
    count = 0
    try:
        for data in loader():
            count += 1
            if count % log_interval == 0:
                batch_time = time.time() - starttime
                starttime = time.time()
                logger.info("reader: {}, {}, {}".format(
                    count, len(data[0]), batch_time))
    except Exception:
        # Remain best-effort (no re-raise) but report the failure with its
        # traceback; previously this fell through and logged "Success!".
        logger.error("reader failed after {} batches:\n{}".format(
            count, traceback.format_exc()))
        return
    logger.info("finish reader: {}, Success!".format(count))


if __name__ == '__main__':
    # program.preprocess supplies the configuration, device, logger and
    # writer that the training entry point consumes.
    prepared = program.preprocess(is_train=True)
    config, device, logger, vdl_writer = prepared
    profiler_message = f"config.profiler_options: {config.profiler_options}"
    logger.info(profiler_message)
    main(config=config, device=device, logger=logger, vdl_writer=vdl_writer)
    # To exercise only the data pipeline, call
    # test_reader(config, device, logger) instead of main(...).