base_runner.py 20.3 KB
Newer Older
1
# Copyright (c) OpenMMLab. All rights reserved.
2
import copy
Kai Chen's avatar
Kai Chen committed
3
4
import logging
import os.path as osp
5
6
import warnings
from abc import ABCMeta, abstractmethod
Kai Chen's avatar
Kai Chen committed
7
8

import torch
Harry's avatar
Harry committed
9
from torch.optim import Optimizer
Kai Chen's avatar
Kai Chen committed
10

Kai Chen's avatar
Kai Chen committed
11
import mmcv
Harry's avatar
Harry committed
12
from ..parallel import is_module_wrapper
13
from .checkpoint import load_checkpoint
Kai Chen's avatar
Kai Chen committed
14
from .dist_utils import get_dist_info
15
from .hooks import HOOKS, Hook
Kai Chen's avatar
Kai Chen committed
16
from .log_buffer import LogBuffer
17
from .priority import Priority, get_priority
18
from .utils import get_time_str
Kai Chen's avatar
Kai Chen committed
19
20


21
22
23
24
25
26
27
28
29
class BaseRunner(metaclass=ABCMeta):
    """The base class of Runner, a training helper for PyTorch.

    All subclasses should implement the following APIs:

    - ``run()``
    - ``train()``
    - ``val()``
    - ``save_checkpoint()``

    Args:
        model (:obj:`torch.nn.Module`): The model to be run.
        batch_processor (callable): A callable method that processes a data
            batch. The interface of this method should be
            `batch_processor(model, data, train_mode) -> dict`. Deprecated:
            implement ``train_step()`` and ``val_step()`` in the model
            instead.
        optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an
            optimizer (in most cases) or a dict of optimizers (in models that
            require more than one optimizer, e.g., GAN).
        work_dir (str, optional): The working directory to save checkpoints
            and logs. It is converted to an absolute path and created if it
            does not exist. Defaults to None.
        logger (:obj:`logging.Logger`): Logger used during training. The
            signature default of None exists only for backward
            compatibility; any value that is not a :obj:`logging.Logger`
            (including None) raises ``TypeError``.
        meta (dict | None): A dict that records some important information
            such as environment info and seed, which will be logged in
            logger hook. Defaults to None.
        max_epochs (int, optional): Total training epochs.
        max_iters (int, optional): Total training iterations. At most one of
            ``max_epochs`` and ``max_iters`` may be set.

    Raises:
        TypeError: If ``batch_processor`` is not callable, ``optimizer`` is
            not an Optimizer / dict of Optimizers / None, ``logger`` is not a
            :obj:`logging.Logger`, ``meta`` is not a dict or None, or
            ``work_dir`` is not a str or None.
        RuntimeError: If ``batch_processor`` is given while the (unwrapped)
            model also defines ``train_step()`` or ``val_step()``.
        ValueError: If both ``max_epochs`` and ``max_iters`` are set.
    """

    def __init__(self,
                 model,
                 batch_processor=None,
                 optimizer=None,
                 work_dir=None,
                 logger=None,
                 meta=None,
                 max_iters=None,
                 max_epochs=None):
        if batch_processor is not None:
            if not callable(batch_processor):
                raise TypeError('batch_processor must be callable, '
                                f'but got {type(batch_processor)}')
            warnings.warn('batch_processor is deprecated, please implement '
                          'train_step() and val_step() in the model instead.')
            # raise an error if `batch_processor` is not None and
            # `model.train_step()` exists.
            if is_module_wrapper(model):
                _model = model.module
            else:
                _model = model
            if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'):
                raise RuntimeError(
                    'batch_processor and model.train_step()/model.val_step() '
                    'cannot be both available.')
        else:
            assert hasattr(model, 'train_step')

        # check the type of `optimizer`
        if isinstance(optimizer, dict):
            for name, optim in optimizer.items():
                if not isinstance(optim, Optimizer):
                    raise TypeError(
                        f'optimizer must be a dict of torch.optim.Optimizers, '
                        f'but optimizer["{name}"] is a {type(optim)}')
        elif not isinstance(optimizer, Optimizer) and optimizer is not None:
            raise TypeError(
                f'optimizer must be a torch.optim.Optimizer object '
                f'or dict or None, but got {type(optimizer)}')

        # check the type of `logger`
        if not isinstance(logger, logging.Logger):
            raise TypeError(f'logger must be a logging.Logger object, '
                            f'but got {type(logger)}')

        # check the type of `meta`
        if meta is not None and not isinstance(meta, dict):
            raise TypeError(
                f'meta must be a dict or None, but got {type(meta)}')

        self.model = model
        self.batch_processor = batch_processor
        self.optimizer = optimizer
        self.logger = logger
        self.meta = meta

        # create work_dir
        if mmcv.is_str(work_dir):
            self.work_dir = osp.abspath(work_dir)
            mmcv.mkdir_or_exist(self.work_dir)
        elif work_dir is None:
            self.work_dir = None
        else:
            raise TypeError('"work_dir" must be a str or None')

        # get model name from the model class
        if hasattr(self.model, 'module'):
            self._model_name = self.model.module.__class__.__name__
        else:
            self._model_name = self.model.__class__.__name__

        self._rank, self._world_size = get_dist_info()
        self.timestamp = get_time_str()
        self.mode = None
        self._hooks = []
        self._epoch = 0
        self._iter = 0
        self._inner_iter = 0

        if max_epochs is not None and max_iters is not None:
            raise ValueError(
                'Only one of `max_epochs` or `max_iters` can be set.')

        self._max_epochs = max_epochs
        self._max_iters = max_iters
        # TODO: Redesign LogBuffer, it is not flexible and elegant enough
        self.log_buffer = LogBuffer()

    @property
    def model_name(self):
        """str: Name of the model, usually the module class name."""
        return self._model_name

    @property
    def rank(self):
        """int: Rank of current process. (distributed training)"""
        return self._rank

    @property
    def world_size(self):
        """int: Number of processes participating in the job.
        (distributed training)"""
        return self._world_size

    @property
    def hooks(self):
        """list[:obj:`Hook`]: A list of registered hooks."""
        return self._hooks

    @property
    def epoch(self):
        """int: Current epoch."""
        return self._epoch

    @property
    def iter(self):
        """int: Current iteration."""
        return self._iter

    @property
    def inner_iter(self):
        """int: Iteration in an epoch."""
        return self._inner_iter

    @property
    def max_epochs(self):
        """int: Maximum training epochs."""
        return self._max_epochs

    @property
    def max_iters(self):
        """int: Maximum training iterations."""
        return self._max_iters

    @abstractmethod
    def train(self):
        """Training entry point; must be implemented by subclasses."""
        pass

    @abstractmethod
    def val(self):
        """Validation entry point; must be implemented by subclasses."""
        pass

    @abstractmethod
    def run(self, data_loaders, workflow, **kwargs):
        """Drive the given workflow; must be implemented by subclasses."""
        pass

    @abstractmethod
    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl,
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        """Save a checkpoint; must be implemented by subclasses."""
        pass

    def current_lr(self):
        """Get current learning rates.

        Returns:
            list[float] | dict[str, list[float]]: Current learning rates of all
                param groups. If the runner has a dict of optimizers, this
                method will return a dict.

        Raises:
            RuntimeError: If ``self.optimizer`` is neither an Optimizer nor a
                dict (e.g. it is None).
        """
        if isinstance(self.optimizer, torch.optim.Optimizer):
            lr = [group['lr'] for group in self.optimizer.param_groups]
        elif isinstance(self.optimizer, dict):
            lr = dict()
            for name, optim in self.optimizer.items():
                lr[name] = [group['lr'] for group in optim.param_groups]
        else:
            raise RuntimeError(
                'lr is not applicable because optimizer does not exist.')
        return lr

    def current_momentum(self):
        """Get current momentums.

        For each param group the ``'momentum'`` entry is reported; if absent,
        the first element of ``'betas'`` (Adam-style optimizers) is used;
        otherwise 0.

        Returns:
            list[float] | dict[str, list[float]]: Current momentums of all
                param groups. If the runner has a dict of optimizers, this
                method will return a dict.

        Raises:
            RuntimeError: If ``self.optimizer`` is neither an Optimizer nor a
                dict (e.g. it is None).
        """

        def _get_momentum(optimizer):
            momentums = []
            for group in optimizer.param_groups:
                if 'momentum' in group.keys():
                    momentums.append(group['momentum'])
                elif 'betas' in group.keys():
                    momentums.append(group['betas'][0])
                else:
                    momentums.append(0)
            return momentums

        # Same dispatch as `current_lr`: any unsupported optimizer value
        # (not only None) raises instead of leaving `momentums` unbound.
        if isinstance(self.optimizer, torch.optim.Optimizer):
            momentums = _get_momentum(self.optimizer)
        elif isinstance(self.optimizer, dict):
            momentums = dict()
            for name, optim in self.optimizer.items():
                momentums[name] = _get_momentum(optim)
        else:
            raise RuntimeError(
                'momentum is not applicable because optimizer does not exist.')
        return momentums

    def register_hook(self, hook, priority='NORMAL'):
        """Register a hook into the hook list.

        The hook will be inserted into a priority queue, with the specified
        priority (See :class:`Priority` for details of priorities).
        For hooks with the same priority, they will be triggered in the same
        order as they are registered.

        Args:
            hook (:obj:`Hook`): The hook to be registered.
            priority (int or str or :obj:`Priority`): Hook priority.
                Lower value means higher priority.

        Raises:
            ValueError: If ``hook`` already has a ``priority`` attribute
                (e.g. it was registered before).
        """
        assert isinstance(hook, Hook)
        if hasattr(hook, 'priority'):
            raise ValueError('"priority" is a reserved attribute for hooks')
        priority = get_priority(priority)
        hook.priority = priority
        # insert the hook to a sorted list: scan from the back so the new
        # hook lands after every hook with an equal or smaller priority value
        inserted = False
        for i in range(len(self._hooks) - 1, -1, -1):
            if priority >= self._hooks[i].priority:
                self._hooks.insert(i + 1, hook)
                inserted = True
                break
        if not inserted:
            self._hooks.insert(0, hook)

    def register_hook_from_cfg(self, hook_cfg):
        """Register a hook from its cfg.

        Args:
            hook_cfg (dict): Hook config. It must have the key 'type' and may
              have the key 'priority' (defaults to 'NORMAL'). The dict passed
              in is not modified.

        Notes:
            The specific hook class to register should not use 'type' and
            'priority' arguments during initialization.
        """
        hook_cfg = hook_cfg.copy()
        priority = hook_cfg.pop('priority', 'NORMAL')
        hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
        self.register_hook(hook, priority=priority)

    def call_hook(self, fn_name):
        """Call all hooks.

        Args:
            fn_name (str): The function name in each hook to be called, such as
                "before_train_epoch".
        """
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

    def get_hook_info(self):
        """Summarize registered hooks grouped by the stage they trigger at.

        Returns:
            str: One section per stage in ``Hook.stages`` that has at least
                one hook, listing each hook's priority name and class name.
        """
        # Get hooks info in each stage
        stage_hook_map = {stage: [] for stage in Hook.stages}
        for hook in self.hooks:
            try:
                priority = Priority(hook.priority).name
            except ValueError:
                # numeric priority that is not a named Priority member
                priority = hook.priority
            classname = hook.__class__.__name__
            hook_info = f'({priority:<12}) {classname:<35}'
            for trigger_stage in hook.get_triggered_stages():
                stage_hook_map[trigger_stage].append(hook_info)

        stage_hook_infos = []
        for stage in Hook.stages:
            hook_infos = stage_hook_map[stage]
            if len(hook_infos) > 0:
                info = f'{stage}:\n'
                info += '\n'.join(hook_infos)
                info += '\n -------------------- '
                stage_hook_infos.append(info)
        return '\n'.join(stage_hook_infos)

    def load_checkpoint(self,
                        filename,
                        map_location='cpu',
                        strict=False,
                        revise_keys=((r'^module\.', ''), )):
        """Load a checkpoint into ``self.model``.

        Args:
            filename (str): Checkpoint to load.
            map_location (str | callable): Forwarded to the module-level
                ``load_checkpoint``. Defaults to 'cpu'.
            strict (bool): Forwarded to the module-level ``load_checkpoint``.
                Defaults to False.
            revise_keys (Sequence[tuple[str, str]]): ``(pattern, replacement)``
                regex pairs forwarded to the module-level ``load_checkpoint``
                for rewriting state-dict keys. The default removes a leading
                ``module.`` prefix; the dot is escaped so keys such as
                ``module_list.0.weight`` are left intact. An immutable tuple
                is used to avoid a shared mutable default.

        Returns:
            The checkpoint returned by the module-level ``load_checkpoint``
            (:meth:`resume` indexes it as a dict).
        """
        return load_checkpoint(
            self.model,
            filename,
            map_location,
            strict,
            self.logger,
            revise_keys=revise_keys)

    def resume(self,
               checkpoint,
               resume_optimizer=True,
               map_location='default'):
        """Resume model weights, counters, meta and optimizer state.

        Args:
            checkpoint (str): Checkpoint to load via :meth:`load_checkpoint`.
            resume_optimizer (bool): Whether to restore optimizer state when
                the checkpoint contains an ``'optimizer'`` entry.
                Defaults to True.
            map_location (str | callable): ``'default'`` maps storages onto
                the current CUDA device when CUDA is available, otherwise
                uses :meth:`load_checkpoint`'s default; any other value is
                forwarded unchanged.

        Raises:
            TypeError: If optimizer state must be restored but
                ``self.optimizer`` is neither an Optimizer nor a dict.
        """
        if map_location == 'default':
            if torch.cuda.is_available():
                device_id = torch.cuda.current_device()
                checkpoint = self.load_checkpoint(
                    checkpoint,
                    map_location=lambda storage, loc: storage.cuda(device_id))
            else:
                checkpoint = self.load_checkpoint(checkpoint)
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        if self.meta is None:
            self.meta = {}
        self.meta.setdefault('hook_msgs', {})
        # load `last_ckpt`, `best_score`, `best_ckpt`, etc. for hook messages
        self.meta['hook_msgs'].update(checkpoint['meta'].get('hook_msgs', {}))

        # Re-calculate the number of iterations when resuming
        # models with different number of GPUs
        if 'config' in checkpoint['meta']:
            config = mmcv.Config.fromstring(
                checkpoint['meta']['config'], file_format='.py')
            previous_gpu_ids = config.get('gpu_ids', None)
            if previous_gpu_ids and len(previous_gpu_ids) > 0 and len(
                    previous_gpu_ids) != self.world_size:
                self._iter = int(self._iter * len(previous_gpu_ids) /
                                 self.world_size)
                self.logger.info('the iteration number is changed due to '
                                 'change of GPU number')

        # resume meta information meta
        # NOTE(review): this rebinds `self.meta` to the checkpoint's meta
        # dict, so the `hook_msgs` merge above only affects the previously
        # held dict object. Kept as-is; confirm this is the intended result.
        self.meta = checkpoint['meta']

        if 'optimizer' in checkpoint and resume_optimizer:
            if isinstance(self.optimizer, Optimizer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(self.optimizer, dict):
                for k in self.optimizer.keys():
                    self.optimizer[k].load_state_dict(
                        checkpoint['optimizer'][k])
            else:
                raise TypeError(
                    'Optimizer should be dict or torch.optim.Optimizer '
                    f'but got {type(self.optimizer)}')

        self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)

    def register_lr_hook(self, lr_config):
        """Register the LR updater hook with priority 'VERY_HIGH'.

        Args:
            lr_config (dict | :obj:`Hook` | None): A dict with a ``'policy'``
                key is built into ``<Policy>LrUpdaterHook``; a hook instance
                is registered as-is; None registers nothing. A dict passed in
                is not modified, so the same config can be reused.
        """
        if lr_config is None:
            return
        elif isinstance(lr_config, dict):
            assert 'policy' in lr_config
            # work on a copy: popping 'policy' from the caller's config would
            # make a second registration with the same config fail
            lr_config = lr_config.copy()
            policy_type = lr_config.pop('policy')
            # An all-lowercase policy name is title-cased for convenience,
            # e.g. 'cyclic' -> 'Cyclic'. Names containing any capital letter
            # (needed for multi-word policies such as 'CosineAnnealing') are
            # used verbatim.
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            hook_type = policy_type + 'LrUpdaterHook'
            lr_config['type'] = hook_type
            hook = mmcv.build_from_cfg(lr_config, HOOKS)
        else:
            hook = lr_config
        self.register_hook(hook, priority='VERY_HIGH')

    def register_momentum_hook(self, momentum_config):
        """Register the momentum updater hook with priority 'HIGH'.

        Args:
            momentum_config (dict | :obj:`Hook` | None): A dict with a
                ``'policy'`` key is built into
                ``<Policy>MomentumUpdaterHook``; a hook instance is
                registered as-is; None registers nothing. A dict passed in is
                not modified.
        """
        if momentum_config is None:
            return
        if isinstance(momentum_config, dict):
            assert 'policy' in momentum_config
            # copy so the caller's config keeps its 'policy' key
            momentum_config = momentum_config.copy()
            policy_type = momentum_config.pop('policy')
            # An all-lowercase policy name is title-cased for convenience,
            # e.g. 'cyclic' -> 'Cyclic'. Names containing any capital letter
            # (e.g. 'CosineAnnealing') are used verbatim.
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            hook_type = policy_type + 'MomentumUpdaterHook'
            momentum_config['type'] = hook_type
            hook = mmcv.build_from_cfg(momentum_config, HOOKS)
        else:
            hook = momentum_config
        self.register_hook(hook, priority='HIGH')

    def register_optimizer_hook(self, optimizer_config):
        """Register the optimizer hook with priority 'ABOVE_NORMAL'.

        Args:
            optimizer_config (dict | :obj:`Hook` | None): A dict is built
                with ``type`` defaulting to 'OptimizerHook' (the dict passed
                in is not modified); a hook instance is registered as-is;
                None registers nothing.
        """
        if optimizer_config is None:
            return
        if isinstance(optimizer_config, dict):
            optimizer_config = optimizer_config.copy()
            optimizer_config.setdefault('type', 'OptimizerHook')
            hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
        else:
            hook = optimizer_config
        self.register_hook(hook, priority='ABOVE_NORMAL')

    def register_checkpoint_hook(self, checkpoint_config):
        """Register the checkpoint hook with priority 'NORMAL'.

        Args:
            checkpoint_config (dict | :obj:`Hook` | None): A dict is built
                with ``type`` defaulting to 'CheckpointHook' (the dict passed
                in is not modified); a hook instance is registered as-is;
                None registers nothing.
        """
        if checkpoint_config is None:
            return
        if isinstance(checkpoint_config, dict):
            checkpoint_config = checkpoint_config.copy()
            checkpoint_config.setdefault('type', 'CheckpointHook')
            hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
        else:
            hook = checkpoint_config
        self.register_hook(hook, priority='NORMAL')

    def register_logger_hooks(self, log_config):
        """Register every logger hook in ``log_config['hooks']``.

        Each entry is built with ``interval=log_config['interval']`` as a
        default argument and registered with priority 'VERY_LOW'. None
        registers nothing.
        """
        if log_config is None:
            return
        log_interval = log_config['interval']
        for info in log_config['hooks']:
            logger_hook = mmcv.build_from_cfg(
                info, HOOKS, default_args=dict(interval=log_interval))
            self.register_hook(logger_hook, priority='VERY_LOW')

    def register_timer_hook(self, timer_config):
        """Register the timer hook with priority 'LOW'.

        Args:
            timer_config (dict | :obj:`Hook` | None): A dict is deep-copied
                and built; a hook instance is registered as-is; None
                registers nothing.
        """
        if timer_config is None:
            return
        if isinstance(timer_config, dict):
            timer_config_ = copy.deepcopy(timer_config)
            hook = mmcv.build_from_cfg(timer_config_, HOOKS)
        else:
            hook = timer_config
        self.register_hook(hook, priority='LOW')

    def register_custom_hooks(self, custom_config):
        """Register user-defined hooks.

        Args:
            custom_config (dict | :obj:`Hook` | list | None): One item or a
                list of items. Dict items go through
                :meth:`register_hook_from_cfg` (priority taken from the dict,
                default 'NORMAL'); other items are registered as hooks with
                priority 'NORMAL'. None registers nothing.
        """
        if custom_config is None:
            return

        if not isinstance(custom_config, list):
            custom_config = [custom_config]

        for item in custom_config:
            if isinstance(item, dict):
                self.register_hook_from_cfg(item)
            else:
                self.register_hook(item, priority='NORMAL')

    def register_profiler_hook(self, profiler_config):
        """Register the profiler hook with the default priority ('NORMAL').

        Args:
            profiler_config (dict | :obj:`Hook` | None): A dict is built
                with ``type`` defaulting to 'ProfilerHook' (the dict passed
                in is not modified); a hook instance is registered as-is;
                None registers nothing.
        """
        if profiler_config is None:
            return
        if isinstance(profiler_config, dict):
            profiler_config = profiler_config.copy()
            profiler_config.setdefault('type', 'ProfilerHook')
            hook = mmcv.build_from_cfg(profiler_config, HOOKS)
        else:
            hook = profiler_config
        self.register_hook(hook)

    def register_training_hooks(self,
                                lr_config,
                                optimizer_config=None,
                                checkpoint_config=None,
                                log_config=None,
                                momentum_config=None,
                                timer_config=dict(type='IterTimerHook'),
                                custom_hooks_config=None):
        """Register default and custom hooks for training.

        Default and custom hooks include:

        +----------------------+-------------------------+
        | Hooks                | Priority                |
        +======================+=========================+
        | LrUpdaterHook        | VERY_HIGH (10)          |
        +----------------------+-------------------------+
        | MomentumUpdaterHook  | HIGH (30)               |
        +----------------------+-------------------------+
        | OptimizerHook        | ABOVE_NORMAL (40)       |
        +----------------------+-------------------------+
        | CheckpointHook       | NORMAL (50)             |
        +----------------------+-------------------------+
        | IterTimerHook        | LOW (70)                |
        +----------------------+-------------------------+
        | LoggerHook(s)        | VERY_LOW (90)           |
        +----------------------+-------------------------+
        | CustomHook(s)        | defaults to NORMAL (50) |
        +----------------------+-------------------------+

        If custom hooks have same priority with default hooks, custom hooks
        will be triggered after default hooks.
        """
        self.register_lr_hook(lr_config)
        self.register_momentum_hook(momentum_config)
        self.register_optimizer_hook(optimizer_config)
        self.register_checkpoint_hook(checkpoint_config)
        self.register_timer_hook(timer_config)
        self.register_logger_hooks(log_config)
        self.register_custom_hooks(custom_hooks_config)