program.py 14.3 KB
Newer Older
LDOUBLEV's avatar
LDOUBLEV committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

WenmuZhou's avatar
WenmuZhou committed
19
import os
LDOUBLEV's avatar
LDOUBLEV committed
20
21
22
import sys
import yaml
import time
WenmuZhou's avatar
WenmuZhou committed
23
24
25
26
27
28
import shutil
import paddle
import paddle.distributed as dist
from tqdm import tqdm
from argparse import ArgumentParser, RawDescriptionHelpFormatter

LDOUBLEV's avatar
LDOUBLEV committed
29
30
from ppocr.utils.stats import TrainingStats
from ppocr.utils.save_load import save_model
dyning's avatar
dyning committed
31
32
33
34
from ppocr.utils.utility import print_dict
from ppocr.utils.logging import get_logger
from ppocr.data import build_dataloader
import numpy as np
LDOUBLEV's avatar
LDOUBLEV committed
35

dyning's avatar
dyning committed
36

LDOUBLEV's avatar
LDOUBLEV committed
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
class ArgsParser(ArgumentParser):
    """Command-line parser for the ``-c/--config`` and ``-o/--opt`` flags.

    ``-o`` takes one or more ``key=value`` overrides; each value is decoded
    with YAML so numbers, booleans and lists get their natural types.
    """

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument(
            "-o", "--opt", nargs='+', help="set configuration options")

    def parse_args(self, argv=None):
        """Parse ``argv`` (``sys.argv[1:]`` when None) and decode ``--opt``.

        Returns:
            argparse.Namespace whose ``opt`` attribute is a dict mapping each
            override key to its YAML-decoded value ({} when none were given).
        Raises:
            AssertionError: if ``--config`` was not given.
        """
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """Turn a list of ``key=value`` strings into a dict.

        Only the first ``=`` separates key from value, so values may
        themselves contain ``=``.

        Raises:
            ValueError: if an option contains no ``=``.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            key, sep, value = s.strip().partition('=')
            if not sep:
                raise ValueError(
                    "option '{}' must be in key=value form".format(s))
            # NOTE(review): yaml.Loader is not a safe loader; acceptable for
            # operator-supplied CLI flags, but never route untrusted text here.
            config[key] = yaml.load(value, Loader=yaml.Loader)
        return config


class AttrDict(dict):
    """A ``dict`` whose top-level keys can also be read as attributes.

    Only the outermost level gets attribute access; nested dicts stay plain
    dicts. Reading an attribute that is not a key raises ``AttributeError``.
    """

    def __init__(self, **kwargs):
        super(AttrDict, self).__init__(**kwargs)

    def __getattr__(self, key):
        # Python calls __getattr__ only after normal lookup fails, so real
        # dict methods (items, update, ...) are never shadowed by keys.
        if key not in self:
            raise AttributeError(
                "object has no attribute '{}'".format(key))
        return self[key]


# Module-level config store: merge_config() writes into it and load_config()
# returns it.
global_config = AttrDict()

lyl120117's avatar
lyl120117 committed
78
79
# Baseline values that load_config() merges into global_config before the
# YAML file is merged on top of them.
default_config = {'Global': {'debug': False, }}

LDOUBLEV's avatar
LDOUBLEV committed
80
81
82
83
84
85
86
87

def load_config(file_path):
    """
    Load config from yml/yaml file.

    ``default_config`` is merged into ``global_config`` first, then the YAML
    file is merged on top of it.

    Args:
        file_path (str): Path of the config file to be loaded.
    Returns: global config
    Raises:
        AssertionError: if the extension is not ``.yml`` or ``.yaml``.
    """
    import copy

    # Merge a copy: merge_config stores dict values by reference and later
    # merges update them in place, which would otherwise mutate the
    # module-level defaults.
    merge_config(copy.deepcopy(default_config))
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"
    # Context manager so the file handle is closed deterministically.
    # NOTE(review): yaml.Loader is not a safe loader; config files are
    # assumed to be trusted.
    with open(file_path, 'rb') as f:
        merge_config(yaml.load(f, Loader=yaml.Loader))
    return global_config


def merge_config(config, target=None):
    """
    Merge config into global config.

    A key without a dot is merged one level deep: a dict value is
    ``update``-d into an existing dict stored under the same key; anything
    else replaces the stored value. A dotted key such as
    ``Global.epoch_num`` walks into nested dicts and sets the final key.

    Args:
        config (dict): Config to be merged.
        target (dict, optional): Dict to merge into; defaults to the
            module-level ``global_config``.
    Returns: the merged target (``global_config`` by default)
    Raises:
        AssertionError: if the first part of a dotted key is not a top-level
            key of the target.
        KeyError: if an intermediate part of a dotted key does not exist.
    """
    if target is None:
        target = global_config
    for key, value in config.items():
        if "." not in key:
            # Merge only when both sides are dicts, so an existing non-dict
            # value (e.g. None) is replaced instead of raising AttributeError.
            if isinstance(value, dict) and isinstance(target.get(key), dict):
                target[key].update(value)
            else:
                target[key] = value
        else:
            sub_keys = key.split('.')
            assert (
                sub_keys[0] in target
            ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
                target.keys(), sub_keys[0])
            cur = target[sub_keys[0]]
            # Walk every intermediate level, then assign the leaf.
            for sub_key in sub_keys[1:-1]:
                cur = cur[sub_key]
            cur[sub_keys[-1]] = value
    return target


def check_gpu(use_gpu):
    """
    Log error and exit when set use_gpu=true in paddlepaddle
    cpu version.

    The CUDA capability query is best-effort: if asking paddle itself fails,
    the check is skipped rather than aborting start-up.

    Args:
        use_gpu (bool): value of ``Global.use_gpu`` from the config.
    """
    if not use_gpu:
        return
    try:
        compiled_with_cuda = paddle.is_compiled_with_cuda()
    except Exception:
        # Deliberate best-effort: never fail start-up because the capability
        # query itself errored.
        return
    if not compiled_with_cuda:
        err = "Config use_gpu cannot be set as true while you are " \
              "using paddlepaddle cpu version ! \nPlease try: \n" \
              "\t1. Install paddlepaddle-gpu to run model on GPU \n" \
              "\t2. Set use_gpu as false in config file to run " \
              "model on CPU"
        print(err)
        sys.exit(1)


WenmuZhou's avatar
WenmuZhou committed
141
def _join_metrics(title, metrics):
    """Render ``metrics`` as ``'<title>, k1: v1, k2: v2, ...'`` for logging."""
    return '{}, {}'.format(title, ', '.join(
        '{}: {}'.format(k, v) for k, v in metrics.items()))


def train(config,
          train_dataloader,
          valid_dataloader,
          device,
          model,
          loss_class,
          optimizer,
          lr_scheduler,
          post_process_class,
          eval_class,
          pre_best_model_dict,
          logger,
          vdl_writer=None):
    """Run the training loop with periodic evaluation and checkpointing.

    Args:
        config (dict): full config. Reads ``Global`` keys log_smooth_window,
            epoch_num, print_batch_step, eval_batch_step, save_epoch_step,
            save_model_dir and optional cal_metric_during_train, plus
            ``Architecture.algorithm``.
        train_dataloader: kept for interface compatibility; a fresh loader
            is built every epoch via ``build_dataloader(..., seed=epoch)``.
        valid_dataloader: loader passed to ``eval`` at evaluation steps.
        device: forwarded to ``build_dataloader``.
        model, loss_class, optimizer, lr_scheduler: training components.
            ``lr_scheduler`` may be a plain float, in which case it is not
            stepped.
        post_process_class, eval_class: used for evaluation and, when
            ``cal_metric_during_train`` is set, for training metrics.
        pre_best_model_dict (dict): metrics restored from a checkpoint; may
            hold ``start_epoch`` to resume from.
        logger: logger for progress lines.
        vdl_writer: optional VisualDL writer; closed at the end on rank 0.
    """
    cal_metric_during_train = config['Global'].get('cal_metric_during_train',
                                                   False)
    log_smooth_window = config['Global']['log_smooth_window']
    epoch_num = config['Global']['epoch_num']
    print_batch_step = config['Global']['print_batch_step']
    eval_batch_step = config['Global']['eval_batch_step']

    global_step = 0
    start_eval_step = 0
    # eval_batch_step may be [start_step, interval] instead of an interval.
    if isinstance(eval_batch_step, list) and len(eval_batch_step) >= 2:
        start_eval_step = eval_batch_step[0]
        eval_batch_step = eval_batch_step[1]
        logger.info(
            "During the training process, after the {}th iteration, an evaluation is run every {} iterations".
            format(start_eval_step, eval_batch_step))
    save_epoch_step = config['Global']['save_epoch_step']
    save_model_dir = config['Global']['save_model_dir']
    os.makedirs(save_model_dir, exist_ok=True)

    main_indicator = eval_class.main_indicator
    best_model_dict = {main_indicator: 0}
    best_model_dict.update(pre_best_model_dict)
    train_stats = TrainingStats(log_smooth_window, ['lr'])
    model_average = False
    model.train()

    use_srn = config['Architecture']['algorithm'] == "SRN"
    start_epoch = best_model_dict.get('start_epoch', 1)

    # Throughput counters for the current logging window. They are reset
    # only after each log line and averaged over the iterations actually
    # measured, so the first window and windows spanning an epoch boundary
    # are reported correctly.
    train_batch_cost = 0.0
    train_reader_cost = 0.0
    batch_sum = 0
    window_iters = 0

    for epoch in range(start_epoch, epoch_num + 1):
        # Rebuild the loader every epoch with a per-epoch seed.
        train_dataloader = build_dataloader(
            config, 'Train', device, logger, seed=epoch)
        batch_start = time.time()
        for idx, batch in enumerate(train_dataloader):
            train_reader_cost += time.time() - batch_start
            if idx >= len(train_dataloader):
                break
            lr = optimizer.get_lr()
            images = batch[0]
            if use_srn:
                # SRN takes the last four batch entries as extra inputs.
                others = batch[-4:]
                preds = model(images, others)
                model_average = True
            else:
                preds = model(images)
            loss = loss_class(preds, batch)
            avg_loss = loss['loss']
            avg_loss.backward()
            optimizer.step()
            optimizer.clear_grad()

            train_batch_cost += time.time() - batch_start
            batch_sum += len(images)
            window_iters += 1

            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            # logger and visualdl
            stats = {k: v.numpy().mean() for k, v in loss.items()}
            stats['lr'] = lr
            train_stats.update(stats)

            if cal_metric_during_train:  # only rec and cls need
                batch = [item.numpy() for item in batch]
                post_result = post_process_class(preds, batch[1])
                eval_class(post_result, batch)
                metric = eval_class.get_metric()
                train_stats.update(metric)

            if vdl_writer is not None and dist.get_rank() == 0:
                for k, v in train_stats.get().items():
                    vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)
                vdl_writer.add_scalar('TRAIN/lr', lr, global_step)

            if dist.get_rank(
            ) == 0 and global_step > 0 and global_step % print_batch_step == 0:
                logs = train_stats.log()
                ips = batch_sum / train_batch_cost if train_batch_cost > 0 else 0.0
                strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format(
                    epoch, epoch_num, global_step, logs,
                    train_reader_cost / window_iters,
                    train_batch_cost / window_iters, batch_sum, ips)
                logger.info(strs)
                train_batch_cost = 0.0
                train_reader_cost = 0.0
                batch_sum = 0
                window_iters = 0

            # eval
            if global_step > start_eval_step and \
                    (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
                if model_average:
                    # NOTE(review): this ModelAverage is created fresh at each
                    # evaluation and never stepped, and in Paddle 2.x apply()
                    # appears to be a context manager, so a bare call would
                    # not swap in averaged weights. Likely a no-op; confirm.
                    Model_Average = paddle.incubate.optimizer.ModelAverage(
                        0.15,
                        parameters=model.parameters(),
                        min_average_window=10000,
                        max_average_window=15625)
                    Model_Average.apply()
                cur_metric = eval(
                    model,
                    valid_dataloader,
                    post_process_class,
                    eval_class,
                    use_srn=use_srn)
                logger.info(_join_metrics('cur metric', cur_metric))

                # logger metric
                if vdl_writer is not None:
                    for k, v in cur_metric.items():
                        if isinstance(v, (float, int)):
                            vdl_writer.add_scalar('EVAL/{}'.format(k), v,
                                                  global_step)
                if cur_metric[main_indicator] >= best_model_dict[
                        main_indicator]:
                    best_model_dict.update(cur_metric)
                    best_model_dict['best_epoch'] = epoch
                    save_model(
                        model,
                        optimizer,
                        save_model_dir,
                        logger,
                        is_best=True,
                        prefix='best_accuracy',
                        best_model_dict=best_model_dict,
                        epoch=epoch)
                logger.info(_join_metrics('best metric', best_model_dict))
                # logger best metric
                if vdl_writer is not None:
                    vdl_writer.add_scalar('EVAL/best_{}'.format(main_indicator),
                                          best_model_dict[main_indicator],
                                          global_step)
            global_step += 1
            batch_start = time.time()
        if dist.get_rank() == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                is_best=False,
                prefix='latest',
                best_model_dict=best_model_dict,
                epoch=epoch)
        if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                is_best=False,
                prefix='iter_epoch_{}'.format(epoch),
                best_model_dict=best_model_dict,
                epoch=epoch)
    logger.info(_join_metrics('best metric', best_model_dict))
    if dist.get_rank() == 0 and vdl_writer is not None:
        vdl_writer.close()
    return


tink2123's avatar
tink2123 committed
325
326
def eval(model, valid_dataloader, post_process_class, eval_class,
         use_srn=False):
    """Evaluate ``model`` on ``valid_dataloader`` and return the metric dict.

    The model is put in eval mode for the pass and is always switched back
    to train mode afterwards, even if evaluation raises.

    Args:
        model: network to evaluate; called as ``model(images)`` or, for SRN,
            ``model(images, batch[-4:])``.
        valid_dataloader: iterable of batches whose items expose ``.numpy()``.
        post_process_class: callable ``(preds, batch[1]) -> post_result``.
        eval_class: accumulating metric callable with ``get_metric()``.
        use_srn (bool): pass the last four batch entries as extra inputs.
    Returns:
        dict: ``eval_class.get_metric()`` plus ``'fps'`` -- samples per second
        over the forward pass and post-processing only (0.0 when no time was
        measured, e.g. an empty loader).
    """
    model.eval()
    total_frame = 0.0
    total_time = 0.0
    pbar = tqdm(total=len(valid_dataloader), desc='eval model:')
    try:
        with paddle.no_grad():
            for idx, batch in enumerate(valid_dataloader):
                if idx >= len(valid_dataloader):
                    break
                images = batch[0]
                start = time.time()
                if use_srn:
                    others = batch[-4:]
                    preds = model(images, others)
                else:
                    preds = model(images)
                batch = [item.numpy() for item in batch]
                # Obtain usable results from post-processing methods
                post_result = post_process_class(preds, batch[1])
                total_time += time.time() - start
                # Evaluate the results of the current batch
                eval_class(post_result, batch)
                pbar.update(1)
                total_frame += len(images)
            # Get final metric, e.g. acc or hmean
            metric = eval_class.get_metric()
    finally:
        pbar.close()
        model.train()
    metric['fps'] = total_frame / total_time if total_time > 0 else 0.0
    return metric
licx's avatar
licx committed
359

tink2123's avatar
tink2123 committed
360

361
def preprocess(is_train=False):
    """Parse CLI flags, build the global config and set up device/logging.

    Args:
        is_train (bool): when True, the merged config is written to
            ``<save_model_dir>/config.yml`` and logs go to
            ``<save_model_dir>/train.log``.
    Returns:
        tuple: ``(config, device, logger, vdl_writer)``; ``vdl_writer`` is
        None unless ``Global.use_visualdl`` is truthy.
    Raises:
        AssertionError: if ``Architecture.algorithm`` is not supported.
    """
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    merge_config(FLAGS.opt)

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global']['use_gpu']
    check_gpu(use_gpu)

    alg = config['Architecture']['algorithm']
    assert alg in [
        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS'
    ]

    device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
    device = paddle.set_device(device)

    config['Global']['distributed'] = dist.get_world_size() != 1
    if is_train:
        # save_config
        save_model_dir = config['Global']['save_model_dir']
        os.makedirs(save_model_dir, exist_ok=True)
        with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
            yaml.dump(
                dict(config), f, default_flow_style=False, sort_keys=False)
        log_file = '{}/train.log'.format(save_model_dir)
    else:
        log_file = None
    logger = get_logger(name='root', log_file=log_file)
    if config['Global'].get('use_visualdl', False):
        from visualdl import LogWriter
        # Read the directory here: the local set in the is_train branch is
        # unbound when evaluating with use_visualdl enabled.
        save_model_dir = config['Global']['save_model_dir']
        vdl_writer_path = '{}/vdl/'.format(save_model_dir)
        os.makedirs(vdl_writer_path, exist_ok=True)
        vdl_writer = LogWriter(logdir=vdl_writer_path)
    else:
        vdl_writer = None
    print_dict(config, logger)
    logger.info('train with paddle {} and device {}'.format(paddle.__version__,
                                                            device))
    return config, device, logger, vdl_writer