# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
WenmuZhou's avatar
WenmuZhou committed
21

22
__dir__ = os.path.dirname(os.path.abspath(__file__))
LDOUBLEV's avatar
LDOUBLEV committed
23
sys.path.append(__dir__)
24
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
LDOUBLEV's avatar
LDOUBLEV committed
25

WenmuZhou's avatar
WenmuZhou committed
26
27
28
import yaml
import paddle
import paddle.distributed as dist
LDOUBLEV's avatar
LDOUBLEV committed
29

dyning's avatar
dyning committed
30
paddle.seed(2)
LDOUBLEV's avatar
LDOUBLEV committed
31

WenmuZhou's avatar
WenmuZhou committed
32
from ppocr.data import build_dataloader
dyning's avatar
dyning committed
33
34
from ppocr.modeling.architectures import build_model
from ppocr.losses import build_loss
WenmuZhou's avatar
WenmuZhou committed
35
36
37
38
39
40
from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model
from ppocr.utils.utility import print_dict
import tools.program as program
LDOUBLEV's avatar
LDOUBLEV committed
41

WenmuZhou's avatar
WenmuZhou committed
42
dist.get_world_size()
LDOUBLEV's avatar
LDOUBLEV committed
43
44


WenmuZhou's avatar
WenmuZhou committed
45
46
47
48
def main(config, device, logger, vdl_writer):
    """Build every training component from ``config`` and start training.

    Args:
        config: Nested configuration mapping. This function reads its
            'Global', 'Eval', 'PostProcess', 'Architecture', 'Loss',
            'Optimizer' and 'Metric' entries.
        device: Passed through to the dataloader builder and to
            ``program.train``.
        logger: Passed through to the dataloader builder, ``init_model`` and
            ``program.train``.
        vdl_writer: Passed through unchanged to ``program.train``.
    """
    global_config = config['Global']

    # The parallel environment has to exist before anything else is built.
    if global_config['distributed']:
        dist.init_parallel_env()

    # Data: the training loader is always built, the evaluation loader only
    # when the 'Eval' entry is truthy.
    train_dataloader = build_dataloader(config, 'Train', device, logger)
    valid_dataloader = (build_dataloader(config, 'Eval', device, logger)
                        if config['Eval'] else None)

    # Post-processing.
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # Model. When the post-processor exposes a `character` attribute (the
    # recognition case), the head's output width is set to its length before
    # the model is constructed.
    architecture_config = config['Architecture']
    if hasattr(post_process_class, 'character'):
        char_num = len(post_process_class.character)
        architecture_config["Head"]['out_channels'] = char_num
    model = build_model(architecture_config)
    if global_config['distributed']:
        model = paddle.DataParallel(model)

    # Loss.
    loss_class = build_loss(config['Loss'])

    # Optimizer and learning-rate scheduler; the schedule is sized from the
    # epoch count and the number of batches per epoch.
    optimizer, lr_scheduler = build_optimizer(
        config['Optimizer'],
        epochs=global_config['epoch_num'],
        step_each_epoch=len(train_dataloader),
        parameters=model.parameters())

    # Evaluation metric.
    eval_class = build_metric(config['Metric'])

    # Load pretrained weights; the returned dict is handed to the train loop.
    pre_best_model_dict = init_model(config, model, logger, optimizer)

    # Run training.
    program.train(config, train_dataloader, valid_dataloader, device, model,
                  loss_class, optimizer, lr_scheduler, post_process_class,
                  eval_class, pre_best_model_dict, logger, vdl_writer)


def test_reader(config, device, logger):
    """Iterate the training dataloader once and log per-batch read timing.

    Debugging aid for the data pipeline: builds the 'Train' dataloader,
    walks through it, and for every batch logs the running batch count,
    ``len(batch)`` and the seconds elapsed since the previous batch.

    Args:
        config: Configuration mapping handed to ``build_dataloader``.
        device: Device handed to ``build_dataloader``.
        logger: Logger used for all output; also handed to
            ``build_dataloader``, matching the calls in ``main``.
    """
    import time

    # Pass the logger as well, consistent with how main() builds its loaders.
    # To exercise the evaluation split instead, build with 'Eval' here.
    loader = build_dataloader(config, 'Train', device, logger)

    count = 0
    starttime = time.time()
    try:
        for data in loader():
            count += 1
            batch_time = time.time() - starttime
            starttime = time.time()
            logger.info("reader: {}, {}, {}".format(count,
                                                    len(data), batch_time))
    except Exception as e:
        # Deliberately broad: this is a best-effort diagnostic, so any reader
        # failure is reported instead of propagated. It must not be reported
        # as a success, though.
        logger.info(e)
        logger.info("finish reader: {}, Failed!".format(count))
    else:
        logger.info("finish reader: {}, Success!".format(count))


if __name__ == '__main__':
    # preprocess() supplies the four values main() takes, in this order.
    config, device, logger, vdl_writer = program.preprocess()
    main(config=config, device=device, logger=logger, vdl_writer=vdl_writer)
    # To debug the data pipeline only, call this instead of main():
    #     test_reader(config, device, logger)