# -*- coding: utf-8 -*-
# @Time    : 2019/8/23 21:50
# @Author  : zhoujun

import os
import pathlib
from pprint import pformat

import anyconfig
import paddle
import numpy as np
import random
from paddle.jit import to_static
from paddle.static import InputSpec

from utils import setup_logger


class BaseTrainer:
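    """
    Base trainer: wires up output dirs, logging, seeding, AMP, VisualDL,
    checkpoint resume/finetune and distributed data-parallel from the config,
    and drives the epoch loop. Subclasses must implement ``_train_epoch``,
    ``_eval``, ``_on_epoch_finish`` and ``_on_train_finish``.

    A minimal usage sketch (``MyTrainer`` is hypothetical; see the concrete
    trainer in this repo for real implementations)::

        class MyTrainer(BaseTrainer):
            def _train_epoch(self, epoch):
                ...  # forward/backward over self.train_loader

        trainer = MyTrainer(config, model, criterion, train_loader,
                            validate_loader, metric_cls, post_process)
        trainer.train()
    """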
    def __init__(self,
                 config,
                 model,
                 criterion,
                 train_loader,
                 validate_loader,
                 metric_cls,
                 post_process=None):
        # note: os.path.abspath(__name__) resolves relative to the current
        # working directory, so output_dir is created under wherever training
        # is launched from
        config['trainer']['output_dir'] = os.path.join(
            str(pathlib.Path(os.path.abspath(__name__)).parent),
            config['trainer']['output_dir'])
        config['name'] = config['name'] + '_' + model.name
        self.save_dir = config['trainer']['output_dir']
        self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint')

        os.makedirs(self.checkpoint_dir, exist_ok=True)

        self.global_step = 0
        self.start_epoch = 0
        self.config = config
        self.criterion = criterion
        # logger and VisualDL
        self.visualdl_enable = self.config['trainer'].get('visual_dl', False)
        self.epochs = self.config['trainer']['epochs']
        self.log_iter = self.config['trainer']['log_iter']
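        # only rank 0 dumps the config and owns the log file, so distributed
        # runs do not write duplicates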
        if paddle.distributed.get_rank() == 0:
            anyconfig.dump(config, os.path.join(self.save_dir, 'config.yaml'))
            self.logger = setup_logger(os.path.join(self.save_dir, 'train.log'))
            self.logger_info(pformat(self.config))

        self.model = self.apply_to_static(model)

        # device
        if (paddle.device.cuda.device_count() > 0 and
                paddle.device.is_compiled_with_cuda()):
            self.with_cuda = True
            # seed Python, NumPy and Paddle RNGs for reproducibility
            random.seed(self.config['trainer']['seed'])
            np.random.seed(self.config['trainer']['seed'])
            paddle.seed(self.config['trainer']['seed'])
        else:
            self.with_cuda = False
        self.logger_info('train with {} and paddle {}'.format(
            'gpu' if self.with_cuda else 'cpu', paddle.__version__))
        # best metrics so far ('best_model_epoch' records when they were reached)
        self.metrics = {
            'recall': 0,
            'precision': 0,
            'hmean': 0,
            'train_loss': float('inf'),
            'best_model_epoch': 0
        }

        self.train_loader = train_loader
        if validate_loader is not None:
            assert post_process is not None and metric_cls is not None
        self.validate_loader = validate_loader
        self.post_process = post_process
        self.metric_cls = metric_cls
        self.train_loader_len = len(train_loader)

        if self.validate_loader is not None:
            self.logger_info(
                'train dataset has {} samples, {} in dataloader; validate dataset has {} samples, {} in dataloader'.
                format(
                    len(self.train_loader.dataset), self.train_loader_len,
                    len(self.validate_loader.dataset),
                    len(self.validate_loader)))
        else:
            self.logger_info(
                'train dataset has {} samples, {} in dataloader'.format(
                    len(self.train_loader.dataset), self.train_loader_len))

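        # create the LR scheduler first: the optimizer takes it as its
        # learning_rate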
        self._initialize_scheduler()

        self._initialize_optimizer()

        # resume or finetune
        if self.config['trainer']['resume_checkpoint'] != '':
            self._load_checkpoint(
                self.config['trainer']['resume_checkpoint'], resume=True)
        elif self.config['trainer']['finetune_checkpoint'] != '':
            self._load_checkpoint(
                self.config['trainer']['finetune_checkpoint'], resume=False)

        if self.visualdl_enable and paddle.distributed.get_rank() == 0:
            from visualdl import LogWriter
            self.writer = LogWriter(self.save_dir)

        # mixed precision (AMP) training
        self.amp = self.config.get('amp', None)
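        # a YAML config may serialize an empty amp section as the string 'None'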
        if self.amp == 'None':
            self.amp = None
        if self.amp:
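            # GradScaler performs (dynamic) loss scaling; decorate() casts the
            # model and optimizer for the requested AMP level (O2 is Paddle's
            # pure-fp16 mode)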
            self.amp['scaler'] = paddle.amp.GradScaler(
                init_loss_scaling=self.amp.get("scale_loss", 1024),
                use_dynamic_loss_scaling=self.amp.get(
                    'use_dynamic_loss_scaling', True))
            self.model, self.optimizer = paddle.amp.decorate(
                models=self.model,
                optimizers=self.optimizer,
                level=self.amp.get('amp_level', 'O2'))

        # distributed training: wrap the model when multiple GPUs are visible
        if paddle.device.cuda.device_count() > 1:
            self.model = paddle.DataParallel(self.model)
        # record the Normalize mean/std from the train transforms so
        # inverse_normalize can undo them
        self.UN_Normalize = False
        for t in self.config['dataset']['train']['dataset']['args'][
                'transforms']:
            if t['type'] == 'Normalize':
                self.normalize_mean = t['args']['mean']
                self.normalize_std = t['args']['std']
                self.UN_Normalize = True

    def apply_to_static(self, model):
        support_to_static = self.config['trainer'].get('to_static', False)
        if support_to_static:
            # one NCHW image input with dynamic batch size, height and width
            specs = [InputSpec([None, 3, -1, -1])]
            model = to_static(model, input_spec=specs)
            self.logger_info(
                'Successfully applied @to_static with specs: {}'.format(specs))
        return model

    def train(self):
        """
        Full training logic
        """
        for epoch in range(self.start_epoch + 1, self.epochs + 1):
            self.epoch_result = self._train_epoch(epoch)
            self._on_epoch_finish()
        if paddle.distributed.get_rank() == 0 and self.visualdl_enable:
            self.writer.close()
        self._on_train_finish()

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Current epoch number
        """
        raise NotImplementedError

    def _eval(self, epoch):
        """
        Evaluation logic for an epoch

        :param epoch: Current epoch number
        """
        raise NotImplementedError

    def _on_epoch_finish(self):
        raise NotImplementedError

    def _on_train_finish(self):
        raise NotImplementedError

    def _save_checkpoint(self, epoch, file_name):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param log: logging information of the epoch
        :param save_best: if True, rename the saved checkpoint to 'model_best.pth.tar'
        """
        state_dict = self.model.state_dict()
        state = {
            'epoch': epoch,
            'global_step': self.global_step,
            'state_dict': state_dict,
            'optimizer': self.optimizer.state_dict(),
            'config': self.config,
            'metrics': self.metrics
        }
        filename = os.path.join(self.checkpoint_dir, file_name)
        paddle.save(state, filename)

    def _load_checkpoint(self, checkpoint_path, resume):
        """
        Resume from saved checkpoints
        :param checkpoint_path: Checkpoint path to be resumed
        """
        self.logger_info("Loading checkpoint: {} ...".format(checkpoint_path))
        checkpoint = paddle.load(checkpoint_path)
        self.model.set_state_dict(checkpoint['state_dict'])
        if resume:
            self.global_step = checkpoint['global_step']
            self.start_epoch = checkpoint['epoch']
            self.config['lr_scheduler']['args']['last_epoch'] = self.start_epoch
            # self.scheduler.load_state_dict(checkpoint['scheduler'])
            self.optimizer.set_state_dict(checkpoint['optimizer'])
            if 'metrics' in checkpoint:
                self.metrics = checkpoint['metrics']
            self.logger_info("resume from checkpoint {} (epoch {})".format(
                checkpoint_path, self.start_epoch))
        else:
            self.logger_info("finetune from checkpoint {}".format(
                checkpoint_path))

    def _initialize(self, name, module, *args, **kwargs):
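        """
        Reflection helper: instantiate ``getattr(module, cfg['type'])`` with
        ``cfg['args']`` taken from ``self.config[name]``; kwargs passed here
        are merged in and must not collide with the config args.
        """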
        module_name = self.config[name]['type']
        module_args = self.config[name].get('args', {})
        assert all(k not in module_args for k in kwargs), \
            'Overwriting kwargs given in config file is not allowed'
        module_args.update(kwargs)
        return getattr(module, module_name)(*args, **module_args)

    def _initialize_scheduler(self):
        self.lr_scheduler = self._initialize('lr_scheduler',
                                             paddle.optimizer.lr)

    def _initialize_optimizer(self):
        self.optimizer = self._initialize(
            'optimizer',
            paddle.optimizer,
            parameters=self.model.parameters(),
            learning_rate=self.lr_scheduler)

    def inverse_normalize(self, batch_img):
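        """
        Undo the per-channel Normalize transform in place on an NCHW batch:
        x = x * std + mean (e.g. to visualize input images).
        """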
        if self.UN_Normalize:
            for c in range(3):
                batch_img[:, c, :, :] = batch_img[:, c, :, :] * \
                    self.normalize_std[c] + self.normalize_mean[c]

    def logger_info(self, s):
        if paddle.distributed.get_rank() == 0:
            self.logger.info(s)