base_runner.py 17.9 KB
Newer Older
Kai Chen's avatar
Kai Chen committed
1
# Copyright (c) Open-MMLab. All rights reserved.
2
import copy
Kai Chen's avatar
Kai Chen committed
3
4
import logging
import os.path as osp
5
6
import warnings
from abc import ABCMeta, abstractmethod
Kai Chen's avatar
Kai Chen committed
7
8

import torch
Harry's avatar
Harry committed
9
from torch.optim import Optimizer
Kai Chen's avatar
Kai Chen committed
10

Kai Chen's avatar
Kai Chen committed
11
import mmcv
Harry's avatar
Harry committed
12
from ..parallel import is_module_wrapper
13
from .checkpoint import load_checkpoint
Kai Chen's avatar
Kai Chen committed
14
from .dist_utils import get_dist_info
15
from .hooks import HOOKS, Hook
Kai Chen's avatar
Kai Chen committed
16
from .log_buffer import LogBuffer
Kai Chen's avatar
Kai Chen committed
17
from .priority import get_priority
18
from .utils import get_time_str
Kai Chen's avatar
Kai Chen committed
19
20


21
22
23
24
25
26
27
28
29
class BaseRunner(metaclass=ABCMeta):
    """The base class of Runner, a training helper for PyTorch.

    All subclasses should implement the following APIs:

    - ``run()``
    - ``train()``
    - ``val()``
    - ``save_checkpoint()``
Kai Chen's avatar
Kai Chen committed
30
31
32
33
34
35

    Args:
        model (:obj:`torch.nn.Module`): The model to be run.
        batch_processor (callable): A callable method that processes a data
            batch. The interface of this method should be
            `batch_processor(model, data, train_mode) -> dict`
Harry's avatar
Harry committed
36
37
38
        optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an
            optimizer (in most cases) or a dict of optimizers (in models that
            require more than one optimizer, e.g., GAN).
Kai Chen's avatar
Kai Chen committed
39
        work_dir (str, optional): The working directory to save checkpoints
40
41
            and logs. Defaults to None.
        logger (:obj:`logging.Logger`): Logger used during training.
Harry's avatar
Harry committed
42
43
             Defaults to None. (The default value is just for backward
             compatibility)
Cao Yuhang's avatar
Cao Yuhang committed
44
45
        meta (dict | None): A dict that records important information such as
            environment info and seed, which will be logged in logger hook.
46
            Defaults to None.
47
48
        max_epochs (int, optional): Total training epochs.
        max_iters (int, optional): Total training iterations.
Kai Chen's avatar
Kai Chen committed
49
    """
Kai Chen's avatar
Kai Chen committed
50
51
52

    def __init__(self,
                 model,
53
                 batch_processor=None,
54
                 optimizer=None,
Kai Chen's avatar
Kai Chen committed
55
                 work_dir=None,
Cao Yuhang's avatar
Cao Yuhang committed
56
                 logger=None,
57
58
59
                 meta=None,
                 max_iters=None,
                 max_epochs=None):
60
61
62
63
64
65
66
67
        if batch_processor is not None:
            if not callable(batch_processor):
                raise TypeError('batch_processor must be callable, '
                                f'but got {type(batch_processor)}')
            warnings.warn('batch_processor is deprecated, please implement '
                          'train_step() and val_step() in the model instead.')
            # raise an error is `batch_processor` is not None and
            # `model.train_step()` exists.
Harry's avatar
Harry committed
68
            if is_module_wrapper(model):
Kai Chen's avatar
Kai Chen committed
69
70
71
72
                _model = model.module
            else:
                _model = model
            if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'):
73
74
75
                raise RuntimeError(
                    'batch_processor and model.train_step()/model.val_step() '
                    'cannot be both available.')
76
        else:
77
            assert hasattr(model, 'train_step')
Harry's avatar
Harry committed
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100

        # check the type of `optimizer`
        if isinstance(optimizer, dict):
            for name, optim in optimizer.items():
                if not isinstance(optim, Optimizer):
                    raise TypeError(
                        f'optimizer must be a dict of torch.optim.Optimizers, '
                        f'but optimizer["{name}"] is a {type(optim)}')
        elif not isinstance(optimizer, Optimizer) and optimizer is not None:
            raise TypeError(
                f'optimizer must be a torch.optim.Optimizer object '
                f'or dict or None, but got {type(optimizer)}')

        # check the type of `logger`
        if not isinstance(logger, logging.Logger):
            raise TypeError(f'logger must be a logging.Logger object, '
                            f'but got {type(logger)}')

        # check the type of `meta`
        if meta is not None and not isinstance(meta, dict):
            raise TypeError(
                f'meta must be a dict or None, but got {type(meta)}')

101
        self.model = model
Kai Chen's avatar
Kai Chen committed
102
        self.batch_processor = batch_processor
103
        self.optimizer = optimizer
Harry's avatar
Harry committed
104
105
        self.logger = logger
        self.meta = meta
Kai Chen's avatar
Kai Chen committed
106
107
108
109
110
111
112
113
114
115
        # create work_dir
        if mmcv.is_str(work_dir):
            self.work_dir = osp.abspath(work_dir)
            mmcv.mkdir_or_exist(self.work_dir)
        elif work_dir is None:
            self.work_dir = None
        else:
            raise TypeError('"work_dir" must be a str or None')

        # get model name from the model class
116
        if hasattr(self.model, 'module'):
Kai Chen's avatar
Kai Chen committed
117
118
119
120
            self._model_name = self.model.module.__class__.__name__
        else:
            self._model_name = self.model.__class__.__name__

121
122
        self._rank, self._world_size = get_dist_info()
        self.timestamp = get_time_str()
Kai Chen's avatar
Kai Chen committed
123
124
125
126
127
        self.mode = None
        self._hooks = []
        self._epoch = 0
        self._iter = 0
        self._inner_iter = 0
128
129
130
131
132
133
134

        if max_epochs is not None and max_iters is not None:
            raise ValueError(
                'Only one of `max_epochs` or `max_iters` can be set.')

        self._max_epochs = max_epochs
        self._max_iters = max_iters
135
136
        # TODO: Redesign LogBuffer, it is not flexible and elegant enough
        self.log_buffer = LogBuffer()
Kai Chen's avatar
Kai Chen committed
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183

    @property
    def model_name(self):
        """str: Class name of the model (of ``model.module`` when the model
        exposes a ``module`` attribute), as captured in ``__init__``."""
        return self._model_name

    @property
    def rank(self):
        """int: Rank of the current process, as reported by
        ``get_dist_info()`` during ``__init__``. (distributed training)"""
        return self._rank

    @property
    def world_size(self):
        """int: Number of processes participating in the job, as reported by
        ``get_dist_info()`` during ``__init__``. (distributed training)"""
        return self._world_size

    @property
    def hooks(self):
        """list[:obj:`Hook`]: The registered hooks. This is the internal
        list itself, not a copy."""
        return self._hooks

    @property
    def epoch(self):
        """int: Current epoch (starts at 0, restored by ``resume()``)."""
        return self._epoch

    @property
    def iter(self):
        """int: Current iteration (starts at 0, restored by ``resume()``)."""
        return self._iter

    @property
    def inner_iter(self):
        """int: Iteration in an epoch."""
        return self._inner_iter

    @property
    def max_epochs(self):
        """int | None: Maximum training epochs; None when not set."""
        return self._max_epochs

    @property
    def max_iters(self):
        """int | None: Maximum training iterations; None when not set."""
        return self._max_iters

184
185
186
    @abstractmethod
    def train(self):
        """Abstract method that every concrete runner must implement.

        Presumably runs the training routine; the exact contract is defined
        by the subclass.
        """
        pass
Kai Chen's avatar
Kai Chen committed
187

188
189
190
    @abstractmethod
    def val(self):
        """Abstract method that every concrete runner must implement.

        Presumably runs the validation routine; the exact contract is defined
        by the subclass.
        """
        pass
Kai Chen's avatar
Kai Chen committed
191

192
193
194
    @abstractmethod
    def run(self, data_loaders, workflow, **kwargs):
        """Abstract entry point that every concrete runner must implement.

        Args:
            data_loaders: Data loaders to run on (presumably one per workflow
                stage; the concrete type is defined by the subclass).
            workflow: Description of the stages to execute (format defined
                by the subclass).
            **kwargs: Extra arguments interpreted by the subclass.
        """
        pass
Kai Chen's avatar
Kai Chen committed
195

196
197
198
199
200
201
202
203
    @abstractmethod
    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl,
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        """Abstract method that every concrete runner must implement.

        The parameter descriptions below are inferred from the names only;
        the actual behaviour is defined by the subclass.

        Args:
            out_dir: Presumably the directory the checkpoint is written to.
            filename_tmpl: Presumably a template for the checkpoint filename.
            save_optimizer (bool): Presumably whether the optimizer state is
                stored as well. Defaults to True.
            meta: Presumably extra metadata stored with the checkpoint.
                Defaults to None.
            create_symlink (bool): Presumably whether a link to the newest
                checkpoint is created. Defaults to True.
        """
        pass
Kai Chen's avatar
Kai Chen committed
204
205
206
207
208

    def current_lr(self):
        """Get current learning rates.

        Returns:
Harry's avatar
Harry committed
209
210
211
            list[float] | dict[str, list[float]]: Current learning rates of all
                param groups. If the runner has a dict of optimizers, this
                method will return a dict.
Kai Chen's avatar
Kai Chen committed
212
        """
Harry's avatar
Harry committed
213
214
215
216
217
218
219
        if isinstance(self.optimizer, torch.optim.Optimizer):
            lr = [group['lr'] for group in self.optimizer.param_groups]
        elif isinstance(self.optimizer, dict):
            lr = dict()
            for name, optim in self.optimizer.items():
                lr[name] = [group['lr'] for group in optim.param_groups]
        else:
220
221
            raise RuntimeError(
                'lr is not applicable because optimizer does not exist.')
Harry's avatar
Harry committed
222
        return lr
Kai Chen's avatar
Kai Chen committed
223

Wenwei Zhang's avatar
Wenwei Zhang committed
224
225
226
227
    def current_momentum(self):
        """Get current momentums.

        Returns:
Harry's avatar
Harry committed
228
229
230
            list[float] | dict[str, list[float]]: Current momentums of all
                param groups. If the runner has a dict of optimizers, this
                method will return a dict.
Wenwei Zhang's avatar
Wenwei Zhang committed
231
        """
Harry's avatar
Harry committed
232
233
234
235
236
237
238
239
240
241
242
243

        def _get_momentum(optimizer):
            momentums = []
            for group in optimizer.param_groups:
                if 'momentum' in group.keys():
                    momentums.append(group['momentum'])
                elif 'betas' in group.keys():
                    momentums.append(group['betas'][0])
                else:
                    momentums.append(0)
            return momentums

Wenwei Zhang's avatar
Wenwei Zhang committed
244
245
        if self.optimizer is None:
            raise RuntimeError(
246
                'momentum is not applicable because optimizer does not exist.')
Harry's avatar
Harry committed
247
248
249
250
251
252
        elif isinstance(self.optimizer, torch.optim.Optimizer):
            momentums = _get_momentum(self.optimizer)
        elif isinstance(self.optimizer, dict):
            momentums = dict()
            for name, optim in self.optimizer.items():
                momentums[name] = _get_momentum(optim)
253
        return momentums
Wenwei Zhang's avatar
Wenwei Zhang committed
254

Kai Chen's avatar
Kai Chen committed
255
    def register_hook(self, hook, priority='NORMAL'):
        """Insert a hook into the priority-ordered hook list.

        The list is kept sorted by ascending priority value (see
        :class:`Priority` for details of priorities). A new hook is placed
        after every already registered hook whose priority value is lower or
        equal, so hooks of equal priority are triggered in registration
        order.

        Args:
            hook (:obj:`Hook`): The hook to be registered.
            priority (int or str or :obj:`Priority`): Hook priority.
                Lower value means higher priority.

        Raises:
            ValueError: If the hook already has a ``priority`` attribute.
        """
        assert isinstance(hook, Hook)
        if hasattr(hook, 'priority'):
            raise ValueError('"priority" is a reserved attribute for hooks')
        priority = get_priority(priority)
        hook.priority = priority
        # scan from the tail for the last hook that must stay in front of the
        # new one; without such a hook the new one goes to the very front
        insert_at = 0
        for index in reversed(range(len(self._hooks))):
            if priority >= self._hooks[index].priority:
                insert_at = index + 1
                break
        self._hooks.insert(insert_at, hook)

Wang Xinjiang's avatar
Wang Xinjiang committed
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
    def register_hook_from_cfg(self, hook_cfg):
        """Build a hook from a config dict and register it.

        Args:
            hook_cfg (dict): Hook config. It must contain the key 'type' and
              may contain 'priority' (defaults to 'NORMAL' when absent).

        Notes:
            The specific hook class to register should not use 'type' and
            'priority' arguments during initialization.
        """
        # operate on a copy so that the caller keeps its 'priority' entry
        build_cfg = hook_cfg.copy()
        hook_priority = build_cfg.pop('priority', 'NORMAL')
        self.register_hook(
            mmcv.build_from_cfg(build_cfg, HOOKS), priority=hook_priority)

Kai Chen's avatar
Kai Chen committed
299
    def call_hook(self, fn_name):
300
301
302
303
304
305
        """Call all hooks.

        Args:
            fn_name (str): The function name in each hook to be called, such as
                "before_train_epoch".
        """
Kai Chen's avatar
Kai Chen committed
306
307
308
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

309
310
311
312
313
314
    def load_checkpoint(self,
                        filename,
                        map_location='cpu',
                        strict=False,
                        revise_keys=None):
        """Load a checkpoint into ``self.model``.

        Args:
            filename (str): Checkpoint location, forwarded unchanged to the
                module-level ``load_checkpoint`` helper.
            map_location: Forwarded to the helper. Defaults to 'cpu'.
            strict (bool): Forwarded to the helper. Defaults to False.
            revise_keys (list[tuple[str, str]] | None): Pairs of
                (regular expression, replacement) forwarded to the helper.
                None selects the default ``[(r'^module\\.', '')]``.

        Returns:
            Whatever the module-level ``load_checkpoint`` helper returns
            (``resume()`` indexes it like a dict).
        """
        if revise_keys is None:
            # Build the default per call instead of sharing one mutable list
            # between all calls. The dot is escaped so that only a literal
            # "module." prefix matches, not e.g. "module_list".
            revise_keys = [(r'^module\.', '')]

        self.logger.info('load checkpoint from %s', filename)
        return load_checkpoint(
            self.model,
            filename,
            map_location,
            strict,
            self.logger,
            revise_keys=revise_keys)
Kai Chen's avatar
Kai Chen committed
323

324
325
326
    def resume(self,
               checkpoint,
               resume_optimizer=True,
Kai Chen's avatar
Kai Chen committed
327
328
               map_location='default'):
        if map_location == 'default':
shilong's avatar
shilong committed
329
330
331
332
333
334
335
            if torch.cuda.is_available():
                device_id = torch.cuda.current_device()
                checkpoint = self.load_checkpoint(
                    checkpoint,
                    map_location=lambda storage, loc: storage.cuda(device_id))
            else:
                checkpoint = self.load_checkpoint(checkpoint)
Kai Chen's avatar
Kai Chen committed
336
337
338
339
340
341
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
342
343
344
345
346
        if self.meta is None:
            self.meta = {}
        self.meta.setdefault('hook_msgs', {})
        # load `last_ckpt`, `best_score`, `best_ckpt`, etc. for hook messages
        self.meta['hook_msgs'].update(checkpoint['meta'].get('hook_msgs', {}))
347
348
349
350
351
352
353
354
355
356
357
358
359
360

        # Re-calculate the number of iterations when resuming
        # models with different number of GPUs
        if 'config' in checkpoint['meta']:
            config = mmcv.Config.fromstring(
                checkpoint['meta']['config'], file_format='.py')
            previous_gpu_ids = config.get('gpu_ids', None)
            if previous_gpu_ids and len(previous_gpu_ids) > 0 and len(
                    previous_gpu_ids) != self.world_size:
                self._iter = int(self._iter * len(previous_gpu_ids) /
                                 self.world_size)
                self.logger.info('the iteration number is changed due to '
                                 'change of GPU number')

Kai Chen's avatar
Kai Chen committed
361
        if 'optimizer' in checkpoint and resume_optimizer:
362
363
364
365
366
367
368
369
370
371
            if isinstance(self.optimizer, Optimizer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(self.optimizer, dict):
                for k in self.optimizer.keys():
                    self.optimizer[k].load_state_dict(
                        checkpoint['optimizer'][k])
            else:
                raise TypeError(
                    'Optimizer should be dict or torch.optim.Optimizer '
                    f'but got {type(self.optimizer)}')
Kai Chen's avatar
Kai Chen committed
372
373
374

        self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)

Kai Chen's avatar
Kai Chen committed
375
    def register_lr_hook(self, lr_config):
376
377
378
        if lr_config is None:
            return
        elif isinstance(lr_config, dict):
Kai Chen's avatar
Kai Chen committed
379
            assert 'policy' in lr_config
380
381
382
            policy_type = lr_config.pop('policy')
            # If the type of policy is all in lower case, e.g., 'cyclic',
            # then its first letter will be capitalized, e.g., to be 'Cyclic'.
Ye Liu's avatar
Ye Liu committed
383
            # This is for the convenient usage of Lr updater.
Yawei Li's avatar
Yawei Li committed
384
385
            # Since this is not applicable for `
            # CosineAnnealingLrUpdater`,
386
387
388
389
            # the string will not be changed if it contains capital letters.
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            hook_type = policy_type + 'LrUpdaterHook'
Kai Chen's avatar
Kai Chen committed
390
391
392
393
394
395
            lr_config['type'] = hook_type
            hook = mmcv.build_from_cfg(lr_config, HOOKS)
        else:
            hook = lr_config
        self.register_hook(hook)

396
    def register_momentum_hook(self, momentum_config):
Wenwei Zhang's avatar
Wenwei Zhang committed
397
398
399
400
        if momentum_config is None:
            return
        if isinstance(momentum_config, dict):
            assert 'policy' in momentum_config
401
402
403
404
            policy_type = momentum_config.pop('policy')
            # If the type of policy is all in lower case, e.g., 'cyclic',
            # then its first letter will be capitalized, e.g., to be 'Cyclic'.
            # This is for the convenient usage of momentum updater.
Yawei Li's avatar
Yawei Li committed
405
406
            # Since this is not applicable for
            # `CosineAnnealingMomentumUpdater`,
407
408
409
410
            # the string will not be changed if it contains capital letters.
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            hook_type = policy_type + 'MomentumUpdaterHook'
Wenwei Zhang's avatar
Wenwei Zhang committed
411
412
413
414
415
416
            momentum_config['type'] = hook_type
            hook = mmcv.build_from_cfg(momentum_config, HOOKS)
        else:
            hook = momentum_config
        self.register_hook(hook)

417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
    def register_optimizer_hook(self, optimizer_config):
        if optimizer_config is None:
            return
        if isinstance(optimizer_config, dict):
            optimizer_config.setdefault('type', 'OptimizerHook')
            hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
        else:
            hook = optimizer_config
        self.register_hook(hook)

    def register_checkpoint_hook(self, checkpoint_config):
        if checkpoint_config is None:
            return
        if isinstance(checkpoint_config, dict):
            checkpoint_config.setdefault('type', 'CheckpointHook')
            hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
        else:
            hook = checkpoint_config
        self.register_hook(hook)

Kai Chen's avatar
Kai Chen committed
437
    def register_logger_hooks(self, log_config):
su's avatar
su committed
438
439
        if log_config is None:
            return
Kai Chen's avatar
Kai Chen committed
440
441
        log_interval = log_config['interval']
        for info in log_config['hooks']:
Kai Chen's avatar
Kai Chen committed
442
443
            logger_hook = mmcv.build_from_cfg(
                info, HOOKS, default_args=dict(interval=log_interval))
Kai Chen's avatar
Kai Chen committed
444
            self.register_hook(logger_hook, priority='VERY_LOW')
Kai Chen's avatar
Kai Chen committed
445

446
447
448
449
450
    def register_timer_hook(self, timer_config):
        if timer_config is None:
            return
        if isinstance(timer_config, dict):
            timer_config_ = copy.deepcopy(timer_config)
Miao Zheng's avatar
Miao Zheng committed
451
            hook = mmcv.build_from_cfg(timer_config_, HOOKS)
452
453
454
455
        else:
            hook = timer_config
        self.register_hook(hook)

456
457
458
459
460
461
462
463
464
465
    def register_profiler_hook(self, profiler_config):
        if profiler_config is None:
            return
        if isinstance(profiler_config, dict):
            profiler_config.setdefault('type', 'ProfilerHook')
            hook = mmcv.build_from_cfg(profiler_config, HOOKS)
        else:
            hook = profiler_config
        self.register_hook(hook)

466
467
    def register_training_hooks(self,
                                lr_config,
pangjm's avatar
pangjm committed
468
                                optimizer_config=None,
469
                                checkpoint_config=None,
Wenwei Zhang's avatar
Wenwei Zhang committed
470
                                log_config=None,
471
472
                                momentum_config=None,
                                timer_config=dict(type='IterTimerHook')):
473
        """Register default hooks for training.
Kai Chen's avatar
Kai Chen committed
474
475

        Default hooks include:
Kai Chen's avatar
Kai Chen committed
476

Kai Chen's avatar
Kai Chen committed
477
        - LrUpdaterHook
Wenwei Zhang's avatar
Wenwei Zhang committed
478
        - MomentumUpdaterHook
Kai Chen's avatar
Kai Chen committed
479
480
481
        - OptimizerStepperHook
        - CheckpointSaverHook
        - IterTimerHook
Kai Chen's avatar
Kai Chen committed
482
        - LoggerHook(s)
Kai Chen's avatar
Kai Chen committed
483
        """
Kai Chen's avatar
Kai Chen committed
484
        self.register_lr_hook(lr_config)
485
        self.register_momentum_hook(momentum_config)
Kai Chen's avatar
Kai Chen committed
486
487
        self.register_optimizer_hook(optimizer_config)
        self.register_checkpoint_hook(checkpoint_config)
488
        self.register_timer_hook(timer_config)
Kai Chen's avatar
Kai Chen committed
489
        self.register_logger_hooks(log_config)