program.py 14.9 KB
Newer Older
LDOUBLEV's avatar
LDOUBLEV committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

WenmuZhou's avatar
WenmuZhou committed
19
import os
LDOUBLEV's avatar
LDOUBLEV committed
20
import sys
21
import platform
LDOUBLEV's avatar
LDOUBLEV committed
22
23
import yaml
import time
WenmuZhou's avatar
WenmuZhou committed
24
25
26
27
28
29
import shutil
import paddle
import paddle.distributed as dist
from tqdm import tqdm
from argparse import ArgumentParser, RawDescriptionHelpFormatter

LDOUBLEV's avatar
LDOUBLEV committed
30
31
from ppocr.utils.stats import TrainingStats
from ppocr.utils.save_load import save_model
dyning's avatar
dyning committed
32
33
34
35
from ppocr.utils.utility import print_dict
from ppocr.utils.logging import get_logger
from ppocr.data import build_dataloader
import numpy as np
LDOUBLEV's avatar
LDOUBLEV committed
36

dyning's avatar
dyning committed
37

LDOUBLEV's avatar
LDOUBLEV committed
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
class ArgsParser(ArgumentParser):
    """Command-line parser used by ``preprocess``.

    Accepts ``-c/--config`` (path of the config file, required) and
    ``-o/--opt`` (zero or more ``key=value`` overrides).
    """

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument(
            "-o", "--opt", nargs='+', help="set configuration options")

    def parse_args(self, argv=None):
        """Parse ``argv`` (``sys.argv`` when None) and convert ``--opt``.

        Returns:
            argparse.Namespace whose ``opt`` attribute is a dict.
        Raises:
            AssertionError: if ``--config`` was not given.
        """
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """Convert ``["a.b=1", ...]`` into ``{"a.b": <yaml-parsed value>}``.

        Only the FIRST ``=`` separates key from value, so values may
        themselves contain ``=``.

        Raises:
            ValueError: if an option does not contain ``=``.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            k, sep, v = s.partition('=')
            if not sep:
                raise ValueError(
                    "Option '{}' must be of the form key=value".format(s))
            # NOTE(review): yaml.Loader can construct arbitrary Python
            # objects; acceptable only for trusted CLI input. Consider
            # yaml.safe_load if values may come from elsewhere.
            config[k] = yaml.load(v, Loader=yaml.Loader)
        return config


class AttrDict(dict):
    """A dict whose top-level keys can also be read as attributes.

    Only the first level gets attribute access; nested dicts stay plain
    dicts (this is NOT recursive).
    """

    def __init__(self, **kwargs):
        # dict.__init__ with keyword arguments inserts them as items.
        super(AttrDict, self).__init__(**kwargs)

    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails.
        if key not in self:
            raise AttributeError("object has no attribute '{}'".format(key))
        return self[key]


# Module-wide configuration store; load_config() and merge_config() fill it.
global_config = AttrDict()

# Baseline values merged into global_config before the yaml file is read.
default_config = dict(Global=dict(debug=False))

LDOUBLEV's avatar
LDOUBLEV committed
81
82
83
84
85
86
87
88

def load_config(file_path):
    """
    Load config from yml/yaml file.

    ``default_config`` is merged into ``global_config`` first, then the
    contents of the yaml file are merged on top of it.

    Args:
        file_path (str): Path of the config file to be loaded; must have a
            ``.yml`` or ``.yaml`` extension.
    Returns: global config (the module-level ``global_config`` object)
    Raises:
        AssertionError: if the extension is not ``.yml``/``.yaml``.
    """
    import copy

    # Validate before touching global state, so a bad path leaves
    # global_config unchanged.
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"

    # merge_config stores dict values by reference and later merges update
    # them in place; merge a copy so default_config itself is never mutated.
    merge_config(copy.deepcopy(default_config))

    # NOTE(review): yaml.Loader can construct arbitrary Python objects; the
    # config file is assumed to be trusted.
    with open(file_path, 'rb') as f:
        file_config = yaml.load(f, Loader=yaml.Loader)
    # An empty yaml file loads as None; treat it as "no overrides".
    merge_config(file_config or {})
    return global_config


def merge_config(config):
    """
    Merge config into global config.

    Keys without a dot are merged at the top level of ``global_config``.
    When both the new value and the existing value are dicts, the existing
    dict is updated in place (one level deep). Otherwise the value is
    replaced.

    Dotted keys such as ``Global.use_gpu`` assign a nested entry. Their
    first component must already exist in ``global_config``, and every
    intermediate component must exist too (KeyError otherwise).

    Args:
        config (dict): Config to be merged; ``None``/empty is a no-op.
    Returns: global config
    """
    if not config:
        return global_config
    for key, value in config.items():
        if "." not in key:
            existing = global_config.get(key)
            if isinstance(value, dict) and isinstance(existing, dict):
                existing.update(value)
            else:
                global_config[key] = value
        else:
            sub_keys = key.split('.')
            assert (
                sub_keys[0] in global_config
            ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
                global_config.keys(), sub_keys[0])
            cur = global_config[sub_keys[0]]
            # Descend to the parent of the last component, then assign.
            for sub_key in sub_keys[1:-1]:
                cur = cur[sub_key]
            cur[sub_keys[-1]] = value
    return global_config


def check_gpu(use_gpu):
    """
    Exit the process with status 1 when ``use_gpu`` is truthy but the
    installed paddle was built without CUDA.

    Any ordinary exception raised while checking is deliberately ignored
    (best-effort check).
    """
    message = ("Config use_gpu cannot be set as true while you are "
               "using paddlepaddle cpu version ! \nPlease try: \n"
               "\t1. Install paddlepaddle-gpu to run model on GPU \n"
               "\t2. Set use_gpu as false in config file to run "
               "model on CPU")

    try:
        # Short-circuit: paddle is only queried when a GPU was requested.
        gpu_unavailable = use_gpu and not paddle.is_compiled_with_cuda()
        if gpu_unavailable:
            print(message)
            # SystemExit is not an Exception subclass, so it escapes below.
            sys.exit(1)
    except Exception:
        pass


WenmuZhou's avatar
WenmuZhou committed
142
def train(config,
          train_dataloader,
          valid_dataloader,
          device,
          model,
          loss_class,
          optimizer,
          lr_scheduler,
          post_process_class,
          eval_class,
          pre_best_model_dict,
          logger,
          vdl_writer=None):
    """Run the training loop with periodic evaluation and checkpointing.

    Args:
        config (dict): Full config. Reads ``Global`` keys
            ``log_smooth_window``, ``epoch_num``, ``print_batch_step``,
            ``eval_batch_step``, ``save_epoch_step``, ``save_model_dir`` and
            the optional ``cal_metric_during_train``, plus
            ``Architecture.algorithm``.
        train_dataloader: Not read. The name is rebound to a freshly built
            loader (``build_dataloader(..., seed=epoch)``) at the start of
            every epoch.
        valid_dataloader: Loader passed to ``eval``. If its length is 0,
            evaluation during training is disabled when ``eval_batch_step``
            is a list.
        device: Passed to ``build_dataloader``.
        model, loss_class, optimizer: Training components.
        lr_scheduler: Stepped once per batch unless it is a plain float.
        post_process_class, eval_class: Convert predictions and accumulate
            metrics. ``eval_class.main_indicator`` names the metric used to
            select the best model.
        pre_best_model_dict (dict): Restored state; may contain
            ``global_step`` and ``start_epoch``.
        logger: Logger for progress messages.
        vdl_writer: Optional VisualDL writer, only written on rank 0.
    """
    cal_metric_during_train = config['Global'].get('cal_metric_during_train',
                                                   False)
    log_smooth_window = config['Global']['log_smooth_window']
    epoch_num = config['Global']['epoch_num']
    print_batch_step = config['Global']['print_batch_step']
    eval_batch_step = config['Global']['eval_batch_step']

    global_step = 0
    if 'global_step' in pre_best_model_dict:
        global_step = pre_best_model_dict['global_step']

    # eval_batch_step is either an int (evaluate every N steps from step 0)
    # or [start_step, every_n_steps].
    start_eval_step = 0
    if isinstance(eval_batch_step, list) and len(eval_batch_step) >= 2:
        start_eval_step = eval_batch_step[0]
        eval_batch_step = eval_batch_step[1]
        if len(valid_dataloader) == 0:
            logger.info(
                'No Images in eval dataset, evaluation during training will be disabled'
            )
            # Effectively "never": global_step will not exceed this.
            start_eval_step = 1e111
        logger.info(
            "During the training process, after the {}th iteration, an evaluation is run every {} iterations".
            format(start_eval_step, eval_batch_step))

    save_epoch_step = config['Global']['save_epoch_step']
    save_model_dir = config['Global']['save_model_dir']
    if not os.path.exists(save_model_dir):
        os.makedirs(save_model_dir)

    main_indicator = eval_class.main_indicator
    best_model_dict = {main_indicator: 0}
    best_model_dict.update(pre_best_model_dict)
    train_stats = TrainingStats(log_smooth_window, ['lr'])
    model_average = False
    model.train()

    use_srn = config['Architecture']['algorithm'] == "SRN"

    if 'start_epoch' in best_model_dict:
        start_epoch = best_model_dict['start_epoch']
    else:
        start_epoch = 1

    for epoch in range(start_epoch, epoch_num + 1):
        train_dataloader = build_dataloader(
            config, 'Train', device, logger, seed=epoch)
        # Accumulators for the current logging window.
        train_batch_cost = 0.0
        train_reader_cost = 0.0
        batch_sum = 0
        batch_count = 0  # batches seen in the current logging window
        batch_start = time.time()
        for idx, batch in enumerate(train_dataloader()):
            train_reader_cost += time.time() - batch_start
            lr = optimizer.get_lr()
            images = batch[0]
            if use_srn:
                # SRN takes the last four batch fields as extra model inputs.
                others = batch[-4:]
                preds = model(images, others)
                model_average = True
            else:
                preds = model(images)
            loss = loss_class(preds, batch)
            avg_loss = loss['loss']
            avg_loss.backward()
            optimizer.step()
            optimizer.clear_grad()

            train_batch_cost += time.time() - batch_start
            batch_sum += len(images)
            batch_count += 1

            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            # logger and visualdl
            stats = {k: v.numpy().mean() for k, v in loss.items()}
            stats['lr'] = lr
            train_stats.update(stats)

            if cal_metric_during_train:  # only rec and cls need
                batch = [item.numpy() for item in batch]
                post_result = post_process_class(preds, batch[1])
                eval_class(post_result, batch)
                metric = eval_class.get_metric()
                train_stats.update(metric)

            if vdl_writer is not None and dist.get_rank() == 0:
                for k, v in train_stats.get().items():
                    vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)
                vdl_writer.add_scalar('TRAIN/lr', lr, global_step)

            if dist.get_rank() == 0 and (
                (global_step > 0 and global_step % print_batch_step == 0) or
                (idx >= len(train_dataloader) - 1)):
                logs = train_stats.log()
                # Average over the batches actually accumulated. The window
                # is not always print_batch_step long, e.g. the first window
                # or the end-of-epoch flush. batch_count >= 1 here.
                strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format(
                    epoch, epoch_num, global_step, logs,
                    train_reader_cost / batch_count,
                    train_batch_cost / batch_count, batch_sum,
                    batch_sum / train_batch_cost)
                logger.info(strs)
                train_batch_cost = 0.0
                train_reader_cost = 0.0
                batch_sum = 0
                batch_count = 0

            # eval
            if global_step > start_eval_step and \
                    (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
                if model_average:
                    # NOTE(review): in Paddle 2.x ModelAverage.apply() appears
                    # to be a context manager that averages what step() has
                    # accumulated. Calling it bare on a freshly created
                    # instance may have no effect; confirm intent.
                    Model_Average = paddle.incubate.optimizer.ModelAverage(
                        0.15,
                        parameters=model.parameters(),
                        min_average_window=10000,
                        max_average_window=15625)
                    Model_Average.apply()
                cur_metric = eval(
                    model,
                    valid_dataloader,
                    post_process_class,
                    eval_class,
                    use_srn=use_srn)
                cur_metric_str = 'cur metric, {}'.format(', '.join(
                    ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
                logger.info(cur_metric_str)

                # logger metric
                if vdl_writer is not None:
                    for k, v in cur_metric.items():
                        if isinstance(v, (float, int)):
                            vdl_writer.add_scalar('EVAL/{}'.format(k),
                                                  cur_metric[k], global_step)
                if cur_metric[main_indicator] >= best_model_dict[
                        main_indicator]:
                    best_model_dict.update(cur_metric)
                    best_model_dict['best_epoch'] = epoch
                    save_model(
                        model,
                        optimizer,
                        save_model_dir,
                        logger,
                        is_best=True,
                        prefix='best_accuracy',
                        best_model_dict=best_model_dict,
                        epoch=epoch,
                        global_step=global_step)
                best_str = 'best metric, {}'.format(', '.join([
                    '{}: {}'.format(k, v) for k, v in best_model_dict.items()
                ]))
                logger.info(best_str)
                # logger best metric
                if vdl_writer is not None:
                    vdl_writer.add_scalar('EVAL/best_{}'.format(main_indicator),
                                          best_model_dict[main_indicator],
                                          global_step)
            global_step += 1
            batch_start = time.time()

        if dist.get_rank() == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                is_best=False,
                prefix='latest',
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step)
        if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                is_best=False,
                prefix='iter_epoch_{}'.format(epoch),
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step)
    best_str = 'best metric, {}'.format(', '.join(
        ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
    logger.info(best_str)
    if dist.get_rank() == 0 and vdl_writer is not None:
        vdl_writer.close()
    return


tink2123's avatar
tink2123 committed
335
336
def eval(model, valid_dataloader, post_process_class, eval_class,
         use_srn=False):
    """Evaluate ``model`` on ``valid_dataloader`` and return the metrics.

    The model is put in eval mode for the call and switched back to train
    mode afterwards, even if evaluation raises.

    Args:
        model: Network to evaluate.
        valid_dataloader: Iterable of batches; ``batch[0]`` is fed to the
            model and ``batch[1]`` is passed to ``post_process_class``.
        post_process_class: Callable ``(preds, batch[1]) -> results``.
        eval_class: Metric accumulator called as
            ``eval_class(results, batch)``; ``get_metric()`` returns a dict.
        use_srn (bool): If True, the last four batch fields are passed to
            the model as extra inputs.

    Returns:
        dict: ``eval_class.get_metric()`` plus ``fps``. ``fps`` is the sum
        of ``len(images)`` over the processed batches divided by the
        forward + post-processing seconds. It is 0.0 when no batch was
        processed.
    """
    model.eval()
    total_frame = 0.0
    total_time = 0.0
    pbar = tqdm(total=len(valid_dataloader), desc='eval model:')
    # On Windows the last batch is skipped (kept from the original code;
    # presumably a DataLoader workaround -- confirm before changing).
    max_iter = len(valid_dataloader) - 1 if platform.system(
    ) == "Windows" else len(valid_dataloader)
    try:
        with paddle.no_grad():
            for idx, batch in enumerate(valid_dataloader):
                if idx >= max_iter:
                    break
                images = batch[0]
                start = time.time()

                if use_srn:
                    others = batch[-4:]
                    preds = model(images, others)
                else:
                    preds = model(images)

                batch = [item.numpy() for item in batch]
                # Obtain usable results from post-processing methods
                post_result = post_process_class(preds, batch[1])
                total_time += time.time() - start
                # Evaluate the results of the current batch
                eval_class(post_result, batch)
                pbar.update(1)
                total_frame += len(images)
            # Get final metric, e.g. acc or hmean
            metric = eval_class.get_metric()
    finally:
        pbar.close()
        model.train()

    # Avoid ZeroDivisionError when nothing was evaluated: an empty loader,
    # or a single-batch loader on Windows where max_iter is 0.
    metric['fps'] = total_frame / total_time if total_time > 0 else 0.0
    return metric
licx's avatar
licx committed
371

tink2123's avatar
tink2123 committed
372

373
def preprocess(is_train=False):
    """Parse CLI args, build the global config and set up runtime helpers.

    Args:
        is_train (bool): When True, the merged config is written to
            ``<save_model_dir>/config.yml`` and the logger also writes to
            ``<save_model_dir>/train.log``.

    Returns:
        tuple: ``(config, device, logger, vdl_writer)``. ``vdl_writer`` is
        a VisualDL ``LogWriter`` when ``Global.use_visualdl`` is truthy,
        otherwise ``None``.

    Raises:
        AssertionError: if ``Architecture.algorithm`` is not supported.
    """
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    merge_config(FLAGS.opt)

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global']['use_gpu']
    check_gpu(use_gpu)

    alg = config['Architecture']['algorithm']
    supported_algs = [
        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
        'CLS', 'PGNet'
    ]
    assert alg in supported_algs, \
        "algorithm {} is not supported, expected one of {}".format(
            alg, supported_algs)

    device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
    device = paddle.set_device(device)

    config['Global']['distributed'] = dist.get_world_size() != 1
    if is_train:
        # save_config
        save_model_dir = config['Global']['save_model_dir']
        os.makedirs(save_model_dir, exist_ok=True)
        with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
            yaml.dump(
                dict(config), f, default_flow_style=False, sort_keys=False)
        log_file = '{}/train.log'.format(save_model_dir)
    else:
        log_file = None
    logger = get_logger(name='root', log_file=log_file)

    # use_visualdl is optional: a missing key disables VisualDL rather than
    # raising KeyError.
    if config['Global'].get('use_visualdl', False):
        from visualdl import LogWriter
        save_model_dir = config['Global']['save_model_dir']
        vdl_writer_path = '{}/vdl/'.format(save_model_dir)
        os.makedirs(vdl_writer_path, exist_ok=True)
        vdl_writer = LogWriter(logdir=vdl_writer_path)
    else:
        vdl_writer = None
    print_dict(config, logger)
    logger.info('train with paddle {} and device {}'.format(paddle.__version__,
                                                            device))
    return config, device, logger, vdl_writer