# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))

import yaml
import paddle
import paddle.distributed as dist

from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model
from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import set_seed
import tools.program as program

dist.get_world_size()


def main(config, device, logger, vdl_writer):
    # init dist environment
    if config['Global']['distributed']:
        dist.init_parallel_env()
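        # each process started by ``paddle.distributed.launch`` joins the
        # collective here; gradients sync via paddle.DataParallel below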

    global_config = config['Global']

    # build dataloader
    train_dataloader = build_dataloader(config, 'Train', device, logger)
    if len(train_dataloader) == 0:
        logger.error(
            "No images found in the train dataset. Please ensure that\n"
            "\t1. The number of images in the train label_file_list is no smaller than the batch size.\n"
            "\t2. The annotation file and image paths in the configuration file are set correctly.")
        return

    if config['Eval']:
        valid_dataloader = build_dataloader(config, 'Eval', device, logger)
    else:
        valid_dataloader = None

    # build post process
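    # the decode class named in ``PostProcess`` (e.g. CTCLabelDecode) also
    # carries the character dictionary that sizes the recognition head below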
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model
    # for rec algorithm
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
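        # head widths follow the decoders: a CTC head emits ``char_num``
        # logits, while SAR reserves two extra special symbols, so its head is
        # ``char_num + 2`` wide and its padding index is ``char_num + 1``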
        # distillation model
        if config['Architecture']["algorithm"] in ["Distillation"]:
            for key in config['Architecture']["Models"]:
                if config['Architecture']['Models'][key]['Head'][
                        'name'] == 'MultiHead':  # for multi head
                    if config['PostProcess'][
                            'name'] == 'DistillationSARLabelDecode':
                        char_num = char_num - 2
                    # update SARLoss params
                    assert list(config['Loss']['loss_config_list'][-1].keys())[
                        0] == 'DistillationSARLoss'
                    config['Loss']['loss_config_list'][-1][
                        'DistillationSARLoss']['ignore_index'] = char_num + 1
                    out_channels_list = {}
                    out_channels_list['CTCLabelDecode'] = char_num
                    out_channels_list['SARLabelDecode'] = char_num + 2
                    config['Architecture']['Models'][key]['Head'][
                        'out_channels_list'] = out_channels_list
                else:
                    config['Architecture']["Models"][key]["Head"][
                        'out_channels'] = char_num
        elif config['Architecture']['Head'][
                'name'] == 'MultiHead':  # for multi head
            if config['PostProcess']['name'] == 'SARLabelDecode':
                char_num = char_num - 2
            # update SARLoss params
            assert list(config['Loss']['loss_config_list'][1].keys())[
                0] == 'SARLoss'
            if config['Loss']['loss_config_list'][1]['SARLoss'] is None:
                config['Loss']['loss_config_list'][1]['SARLoss'] = {
                    'ignore_index': char_num + 1
                }
            else:
                config['Loss']['loss_config_list'][1]['SARLoss'][
                    'ignore_index'] = char_num + 1
            out_channels_list = {}
            out_channels_list['CTCLabelDecode'] = char_num
            out_channels_list['SARLabelDecode'] = char_num + 2
            config['Architecture']['Head'][
                'out_channels_list'] = out_channels_list
        else:  # base rec model
            config['Architecture']["Head"]['out_channels'] = char_num

        if config['PostProcess']['name'] == 'SARLabelDecode':  # for SAR model
            config['Loss']['ignore_index'] = char_num - 1

    model = build_model(config['Architecture'])
    if config['Global']['distributed']:
        model = paddle.DataParallel(model)

    # build loss
    loss_class = build_loss(config['Loss'])

    # build optim
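    # the lr scheduler is created together with the optimizer; it needs the
    # epoch count and iterations per epoch to lay out warmup/decay steps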
    optimizer, lr_scheduler = build_optimizer(
        config['Optimizer'],
        epochs=config['Global']['epoch_num'],
        step_each_epoch=len(train_dataloader),
        model=model)

    # build metric
    eval_class = build_metric(config['Metric'])
    # load pretrain model
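    # resumes training state from ``Global.checkpoints`` or initializes
    # weights from ``Global.pretrained_model`` when either is set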
    pre_best_model_dict = load_model(config, model, optimizer,
                                     config['Architecture']["model_type"])
    logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
    if valid_dataloader is not None:
        logger.info('valid dataloader has {} iters'.format(
            len(valid_dataloader)))

    use_amp = config["Global"].get("use_amp", False)
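    # mixed precision: the loss is multiplied by a scale factor before
    # backward so fp16 gradients do not underflow; GradScaler unscales them at
    # step time and adapts the factor when use_dynamic_loss_scaling is enabled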
    if use_amp:
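        # performance flags for AMP training (persistent cuDNN batch-norm
        # kernels and in-place gradient accumulation)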
        AMP_RELATED_FLAGS_SETTING = {
            'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
            'FLAGS_max_inplace_grad_add': 8,
        }
        paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
        scale_loss = config["Global"].get("scale_loss", 1.0)
        use_dynamic_loss_scaling = config["Global"].get(
            "use_dynamic_loss_scaling", False)
        scaler = paddle.amp.GradScaler(
            init_loss_scaling=scale_loss,
            use_dynamic_loss_scaling=use_dynamic_loss_scaling)
    else:
        scaler = None

    # start train
    program.train(config, train_dataloader, valid_dataloader, device, model,
                  loss_class, optimizer, lr_scheduler, post_process_class,
                  eval_class, pre_best_model_dict, logger, vdl_writer, scaler)


def test_reader(config, device, logger):
    """Debug helper: iterate the train dataloader and log per-batch times."""
    import time

    loader = build_dataloader(config, 'Train', device, logger)
    starttime = time.time()
    count = 0
    try:
        for data in loader():
            count += 1
            batch_time = time.time() - starttime
            starttime = time.time()
            logger.info("reader: {}, {}, {}".format(count,
                                                    len(data[0]), batch_time))
    except Exception as e:
        logger.info(e)
    logger.info("finish reader: {}, Success!".format(count))


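# Typical invocations (config path is illustrative, adjust to your setup):
#   python3 tools/train.py -c configs/rec/rec_icdar15_train.yml
# or multi-GPU via the Paddle launcher:
#   python3 -m paddle.distributed.launch --gpus '0,1' \
#       tools/train.py -c configs/rec/rec_icdar15_train.yml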
if __name__ == '__main__':
    config, device, logger, vdl_writer = program.preprocess(is_train=True)
    seed = config['Global'].get('seed', 1024)
    set_seed(seed)
    main(config, device, logger, vdl_writer)
    # test_reader(config, device, logger)