base_runner.py 14.3 KB
Newer Older
Kai Chen's avatar
Kai Chen committed
1
# Copyright (c) Open-MMLab. All rights reserved.
Kai Chen's avatar
Kai Chen committed
2
3
import logging
import os.path as osp
4
5
import warnings
from abc import ABCMeta, abstractmethod
Kai Chen's avatar
Kai Chen committed
6
7

import torch
Harry's avatar
Harry committed
8
from torch.optim import Optimizer
Kai Chen's avatar
Kai Chen committed
9

Kai Chen's avatar
Kai Chen committed
10
import mmcv
Harry's avatar
Harry committed
11
from ..parallel import is_module_wrapper
12
from .checkpoint import load_checkpoint
Kai Chen's avatar
Kai Chen committed
13
from .dist_utils import get_dist_info
Kai Chen's avatar
Kai Chen committed
14
from .hooks import HOOKS, Hook, IterTimerHook
Kai Chen's avatar
Kai Chen committed
15
from .log_buffer import LogBuffer
Kai Chen's avatar
Kai Chen committed
16
from .priority import get_priority
17
from .utils import get_time_str
Kai Chen's avatar
Kai Chen committed
18
19


20
21
22
23
24
25
26
27
28
class BaseRunner(metaclass=ABCMeta):
    """The base class of Runner, a training helper for PyTorch.

    All subclasses should implement the following APIs:

    - ``run()``
    - ``train()``
    - ``val()``
    - ``save_checkpoint()``

    Args:
        model (:obj:`torch.nn.Module`): The model to be run.
        batch_processor (callable): A callable method that processes a data
            batch. The interface of this method should be
            `batch_processor(model, data, train_mode) -> dict`.
            Deprecated: implement ``train_step()`` and ``val_step()`` in the
            model instead.
        optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an
            optimizer (in most cases) or a dict of optimizers (in models that
            require more than one optimizer, e.g., GAN).
        work_dir (str, optional): The working directory to save checkpoints
            and logs. Defaults to None.
        logger (:obj:`logging.Logger`): Logger used during training.
             Defaults to None. (The default value is just for backward
             compatibility; a :obj:`logging.Logger` must actually be passed,
             otherwise a TypeError is raised.)
        meta (dict | None): A dict that records some important information
            such as environment info and seed, which will be logged in logger
            hook. Defaults to None.
    """

    def __init__(self,
                 model,
                 batch_processor=None,
                 optimizer=None,
                 work_dir=None,
                 logger=None,
                 meta=None):
        if batch_processor is not None:
            if not callable(batch_processor):
                raise TypeError('batch_processor must be callable, '
                                f'but got {type(batch_processor)}')
            warnings.warn('batch_processor is deprecated, please implement '
                          'train_step() and val_step() in the model instead.')
            # raise an error if `batch_processor` is not None and
            # `model.train_step()` exists.
            if is_module_wrapper(model):
                _model = model.module
            else:
                _model = model
            if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'):
                raise RuntimeError(
                    'batch_processor and model.train_step()/model.val_step() '
                    'cannot be both available.')
        else:
            # without a batch_processor the model itself must drive training
            assert hasattr(model, 'train_step')

        # check the type of `optimizer`
        if isinstance(optimizer, dict):
            for name, optim in optimizer.items():
                if not isinstance(optim, Optimizer):
                    raise TypeError(
                        f'optimizer must be a dict of torch.optim.Optimizers, '
                        f'but optimizer["{name}"] is a {type(optim)}')
        elif not isinstance(optimizer, Optimizer) and optimizer is not None:
            raise TypeError(
                f'optimizer must be a torch.optim.Optimizer object '
                f'or dict or None, but got {type(optimizer)}')

        # check the type of `logger`
        if not isinstance(logger, logging.Logger):
            raise TypeError(f'logger must be a logging.Logger object, '
                            f'but got {type(logger)}')

        # check the type of `meta`
        if meta is not None and not isinstance(meta, dict):
            raise TypeError(
                f'meta must be a dict or None, but got {type(meta)}')

        self.model = model
        self.batch_processor = batch_processor
        self.optimizer = optimizer
        self.logger = logger
        self.meta = meta

        # create work_dir
        if mmcv.is_str(work_dir):
            self.work_dir = osp.abspath(work_dir)
            mmcv.mkdir_or_exist(self.work_dir)
        elif work_dir is None:
            self.work_dir = None
        else:
            raise TypeError('"work_dir" must be a str or None')

        # get model name from the model class (unwrap wrappers exposing
        # the real model as `.module`)
        if hasattr(self.model, 'module'):
            self._model_name = self.model.module.__class__.__name__
        else:
            self._model_name = self.model.__class__.__name__

        self._rank, self._world_size = get_dist_info()
        self.timestamp = get_time_str()
        self.mode = None
        self._hooks = []
        self._epoch = 0
        self._iter = 0
        self._inner_iter = 0
        self._max_epochs = 0
        self._max_iters = 0
        # TODO: Redesign LogBuffer, it is not flexible and elegant enough
        self.log_buffer = LogBuffer()

    @property
    def model_name(self):
        """str: Name of the model, usually the module class name."""
        return self._model_name

    @property
    def rank(self):
        """int: Rank of current process. (distributed training)"""
        return self._rank

    @property
    def world_size(self):
        """int: Number of processes participating in the job.
        (distributed training)"""
        return self._world_size

    @property
    def hooks(self):
        """list[:obj:`Hook`]: A list of registered hooks."""
        return self._hooks

    @property
    def epoch(self):
        """int: Current epoch."""
        return self._epoch

    @property
    def iter(self):
        """int: Current iteration."""
        return self._iter

    @property
    def inner_iter(self):
        """int: Iteration in an epoch."""
        return self._inner_iter

    @property
    def max_epochs(self):
        """int: Maximum training epochs."""
        return self._max_epochs

    @property
    def max_iters(self):
        """int: Maximum training iterations."""
        return self._max_iters

    @abstractmethod
    def train(self):
        pass

    @abstractmethod
    def val(self):
        pass

    @abstractmethod
    def run(self, data_loaders, workflow, **kwargs):
        pass

    @abstractmethod
    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl,
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        pass

    def current_lr(self):
        """Get current learning rates.

        Returns:
            list[float] | dict[str, list[float]]: Current learning rates of all
                param groups. If the runner has a dict of optimizers, this
                method will return a dict.

        Raises:
            RuntimeError: If the runner has no usable optimizer.
        """
        if isinstance(self.optimizer, Optimizer):
            lr = [group['lr'] for group in self.optimizer.param_groups]
        elif isinstance(self.optimizer, dict):
            lr = {
                name: [group['lr'] for group in optim.param_groups]
                for name, optim in self.optimizer.items()
            }
        else:
            raise RuntimeError(
                'lr is not applicable because optimizer does not exist.')
        return lr

    def current_momentum(self):
        """Get current momentums.

        For each param group the ``momentum`` entry is used if present,
        otherwise the first element of ``betas`` (Adam-style optimizers),
        otherwise 0.

        Returns:
            list[float] | dict[str, list[float]]: Current momentums of all
                param groups. If the runner has a dict of optimizers, this
                method will return a dict.

        Raises:
            RuntimeError: If the runner has no usable optimizer.
        """

        def _get_momentum(optimizer):
            momentums = []
            for group in optimizer.param_groups:
                if 'momentum' in group:
                    momentums.append(group['momentum'])
                elif 'betas' in group:
                    momentums.append(group['betas'][0])
                else:
                    momentums.append(0)
            return momentums

        if isinstance(self.optimizer, Optimizer):
            momentums = _get_momentum(self.optimizer)
        elif isinstance(self.optimizer, dict):
            momentums = {
                name: _get_momentum(optim)
                for name, optim in self.optimizer.items()
            }
        else:
            # covers `None` as well as any unexpected type, which previously
            # fell through to an UnboundLocalError
            raise RuntimeError(
                'momentum is not applicable because optimizer does not exist.')
        return momentums

    def register_hook(self, hook, priority='NORMAL'):
        """Register a hook into the hook list.

        The hook will be inserted into a priority queue, with the specified
        priority (See :class:`Priority` for details of priorities).
        For hooks with the same priority, they will be triggered in the same
        order as they are registered.

        Args:
            hook (:obj:`Hook`): The hook to be registered.
            priority (int or str or :obj:`Priority`): Hook priority.
                Lower value means higher priority.

        Raises:
            ValueError: If ``hook`` already has a ``priority`` attribute,
                which this method reserves for itself.
        """
        assert isinstance(hook, Hook)
        if hasattr(hook, 'priority'):
            raise ValueError('"priority" is a reserved attribute for hooks')
        priority = get_priority(priority)
        hook.priority = priority
        # insert the hook to a sorted list: scan from the end so that a new
        # hook goes after every existing hook of equal priority
        for i in range(len(self._hooks) - 1, -1, -1):
            if priority >= self._hooks[i].priority:
                self._hooks.insert(i + 1, hook)
                break
        else:
            self._hooks.insert(0, hook)

    def call_hook(self, fn_name):
        """Call all hooks.

        Args:
            fn_name (str): The function name in each hook to be called, such as
                "before_train_epoch".
        """
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

    def load_checkpoint(self, filename, map_location='cpu', strict=False):
        """Load a checkpoint into ``self.model`` and return it.

        Args:
            filename (str): Checkpoint to load.
            map_location (str | callable): Forwarded to the module-level
                ``load_checkpoint``. Defaults to 'cpu'.
            strict (bool): Forwarded to the module-level ``load_checkpoint``.
                Defaults to False.

        Returns:
            Whatever the module-level ``load_checkpoint`` returns (used by
            :meth:`resume` as a dict with a ``'meta'`` entry).
        """
        self.logger.info('load checkpoint from %s', filename)
        return load_checkpoint(self.model, filename, map_location, strict,
                               self.logger)

    def resume(self,
               checkpoint,
               resume_optimizer=True,
               map_location='default'):
        """Resume model weights, epoch/iter counters and optimizer state.

        Args:
            checkpoint (str): Checkpoint to load via :meth:`load_checkpoint`.
            resume_optimizer (bool): Whether to restore the optimizer state
                when the checkpoint contains one. Defaults to True.
            map_location (str | callable): Forwarded to
                :meth:`load_checkpoint`. ``'default'`` maps storages onto the
                current CUDA device when CUDA is available and otherwise uses
                the default of :meth:`load_checkpoint`.

        Raises:
            TypeError: If the optimizer state should be restored but the
                runner's optimizer is neither an Optimizer nor a dict.
        """
        if map_location == 'default':
            if torch.cuda.is_available():
                device_id = torch.cuda.current_device()
                checkpoint = self.load_checkpoint(
                    checkpoint,
                    map_location=lambda storage, loc: storage.cuda(device_id))
            else:
                # CPU-only environment: `current_device()` would fail
                checkpoint = self.load_checkpoint(checkpoint)
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        if 'optimizer' in checkpoint and resume_optimizer:
            if isinstance(self.optimizer, Optimizer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(self.optimizer, dict):
                # a dict of optimizers is restored name by name
                for name, optim in self.optimizer.items():
                    optim.load_state_dict(checkpoint['optimizer'][name])
            else:
                raise TypeError(
                    'Optimizer should be dict or torch.optim.Optimizer '
                    f'but got {type(self.optimizer)}')

        self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)

    @staticmethod
    def _build_updater_hook(config, suffix):
        """Build an updater hook from a config dict holding a ``policy`` key.

        Args:
            config (dict): Hook config; ``policy`` selects the hook class and
                the remaining keys are passed to it.
            suffix (str): Class-name suffix appended to the policy name,
                e.g. ``'LrUpdaterHook'``.

        Returns:
            :obj:`Hook`: The built hook.
        """
        assert 'policy' in config
        # shallow copy so the caller's config is not mutated and stays
        # reusable (popping 'policy' in place broke a second registration)
        config = dict(config)
        policy_type = config.pop('policy')
        # An all-lowercase policy name (e.g. 'cyclic') gets its first letter
        # capitalized (e.g. 'Cyclic') for convenience. Names that already
        # contain capitals (e.g. 'CosineAnnealing') are kept as-is because
        # `str.title()` would mangle them.
        if policy_type == policy_type.lower():
            policy_type = policy_type.title()
        config['type'] = policy_type + suffix
        return mmcv.build_from_cfg(config, HOOKS)

    def register_lr_hook(self, lr_config):
        """Register a learning-rate updater hook.

        Args:
            lr_config (dict | :obj:`Hook` | None): A config dict with a
                ``policy`` key (the hook class is ``<Policy>LrUpdaterHook``),
                or an already-built hook. ``None`` registers nothing.
        """
        if lr_config is None:
            return
        if isinstance(lr_config, dict):
            hook = self._build_updater_hook(lr_config, 'LrUpdaterHook')
        else:
            hook = lr_config
        self.register_hook(hook)

    def register_momentum_hook(self, momentum_config):
        """Register a momentum updater hook.

        Args:
            momentum_config (dict | :obj:`Hook` | None): A config dict with a
                ``policy`` key (the hook class is
                ``<Policy>MomentumUpdaterHook``), or an already-built hook.
                ``None`` registers nothing.
        """
        if momentum_config is None:
            return
        if isinstance(momentum_config, dict):
            hook = self._build_updater_hook(momentum_config,
                                            'MomentumUpdaterHook')
        else:
            hook = momentum_config
        self.register_hook(hook)

    def register_optimizer_hook(self, optimizer_config):
        """Register an optimizer hook.

        Args:
            optimizer_config (dict | :obj:`Hook` | None): A config dict whose
                ``type`` defaults to ``'OptimizerHook'``, or an already-built
                hook. ``None`` registers nothing.
        """
        if optimizer_config is None:
            return
        if isinstance(optimizer_config, dict):
            # copy so the caller's config does not gain a 'type' key
            optimizer_config = dict(optimizer_config)
            optimizer_config.setdefault('type', 'OptimizerHook')
            hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
        else:
            hook = optimizer_config
        self.register_hook(hook)

    def register_checkpoint_hook(self, checkpoint_config):
        """Register a checkpoint hook.

        Args:
            checkpoint_config (dict | :obj:`Hook` | None): A config dict whose
                ``type`` defaults to ``'CheckpointHook'``, or an already-built
                hook. ``None`` registers nothing.
        """
        if checkpoint_config is None:
            return
        if isinstance(checkpoint_config, dict):
            # copy so the caller's config does not gain a 'type' key
            checkpoint_config = dict(checkpoint_config)
            checkpoint_config.setdefault('type', 'CheckpointHook')
            hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
        else:
            hook = checkpoint_config
        self.register_hook(hook)

    def register_logger_hooks(self, log_config):
        """Register the logger hooks described by ``log_config``.

        Args:
            log_config (dict | None): Must contain ``interval`` (passed as the
                default ``interval`` argument of every hook) and ``hooks`` (a
                list of hook config dicts). ``None`` registers nothing.
        """
        if log_config is None:
            return
        log_interval = log_config['interval']
        for info in log_config['hooks']:
            logger_hook = mmcv.build_from_cfg(
                info, HOOKS, default_args=dict(interval=log_interval))
            # loggers run last so they see everything other hooks recorded
            self.register_hook(logger_hook, priority='VERY_LOW')

    def register_training_hooks(self,
                                lr_config,
                                optimizer_config=None,
                                checkpoint_config=None,
                                log_config=None,
                                momentum_config=None):
        """Register default hooks for training.

        Default hooks include:

        - LrUpdaterHook
        - MomentumUpdaterHook
        - OptimizerHook
        - CheckpointHook
        - IterTimerHook
        - LoggerHook(s)

        Every config except the IterTimerHook one may be ``None``, in which
        case the corresponding hook is not registered.
        """
        self.register_lr_hook(lr_config)
        self.register_momentum_hook(momentum_config)
        self.register_optimizer_hook(optimizer_config)
        self.register_checkpoint_hook(checkpoint_config)
        self.register_hook(IterTimerHook())
        self.register_logger_hooks(log_config)