# Copyright (c) OpenMMLab. All rights reserved.
import copy
import logging
import os.path as osp
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from typing import (Any, Callable, Dict, List, Optional, Tuple, Union,
                    no_type_check)

import torch
from torch.optim import Optimizer
from torch.utils.data import DataLoader

import mmcv
from ..parallel import is_module_wrapper
from .checkpoint import load_checkpoint
from .dist_utils import get_dist_info
from .hooks import HOOKS, Hook
from .log_buffer import LogBuffer
from .priority import Priority, get_priority
from .utils import get_time_str


class BaseRunner(metaclass=ABCMeta):
    """The base class of Runner, a training helper for PyTorch.

    All subclasses should implement the following APIs:

    - ``run()``
    - ``train()``
    - ``val()``
    - ``save_checkpoint()``

    Args:
        model (:obj:`torch.nn.Module`): The model to be run.
        batch_processor (callable): A callable method that processes a data
            batch. The interface of this method should be
            `batch_processor(model, data, train_mode) -> dict`.
            Deprecated: implement ``train_step()`` and ``val_step()`` in the
            model instead.
        optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an
            optimizer (in most cases) or a dict of optimizers (in models that
            require more than one optimizer, e.g., GAN).
        work_dir (str, optional): The working directory to save checkpoints
            and logs. It is created if it does not exist. Defaults to None.
        logger (:obj:`logging.Logger`): Logger used during training.
            The signature default of None exists only for backward
            compatibility; a :obj:`logging.Logger` instance is required and
            passing None raises a ``TypeError``.
        meta (dict | None): A dict that records some important information
            such as environment info and seed, which will be logged in logger
            hook. Defaults to None.
        max_iters (int, optional): Total training iterations.
        max_epochs (int, optional): Total training epochs. At most one of
            ``max_epochs`` and ``max_iters`` may be set.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 batch_processor: Optional[Callable] = None,
                 optimizer: Union[Dict, torch.optim.Optimizer, None] = None,
                 work_dir: Optional[str] = None,
                 logger: Optional[logging.Logger] = None,
                 meta: Optional[Dict] = None,
                 max_iters: Optional[int] = None,
                 max_epochs: Optional[int] = None) -> None:
        if batch_processor is not None:
            if not callable(batch_processor):
                raise TypeError('batch_processor must be callable, '
                                f'but got {type(batch_processor)}')
            warnings.warn(
                'batch_processor is deprecated, please implement '
                'train_step() and val_step() in the model instead.',
                DeprecationWarning)
            # raise an error if `batch_processor` is not None and
            # `model.train_step()` exists.
            if is_module_wrapper(model):
                _model = model.module
            else:
                _model = model
            if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'):
                raise RuntimeError(
                    'batch_processor and model.train_step()/model.val_step() '
                    'cannot be both available.')
        else:
            # Without a batch_processor the model itself must drive training.
            assert hasattr(model, 'train_step'), (
                'model must implement train_step() when batch_processor '
                'is None')

        # check the type of `optimizer`
        if isinstance(optimizer, dict):
            for name, optim in optimizer.items():
                if not isinstance(optim, Optimizer):
                    raise TypeError(
                        f'optimizer must be a dict of torch.optim.Optimizers, '
                        f'but optimizer["{name}"] is a {type(optim)}')
        elif not isinstance(optimizer, Optimizer) and optimizer is not None:
            raise TypeError(
                f'optimizer must be a torch.optim.Optimizer object '
                f'or dict or None, but got {type(optimizer)}')

        # check the type of `logger`
        if not isinstance(logger, logging.Logger):
            raise TypeError(f'logger must be a logging.Logger object, '
                            f'but got {type(logger)}')

        # check the type of `meta`
        if meta is not None and not isinstance(meta, dict):
            raise TypeError(
                f'meta must be a dict or None, but got {type(meta)}')

        self.model = model
        self.batch_processor = batch_processor
        self.optimizer = optimizer
        self.logger = logger
        self.meta = meta
        # create work_dir
        if isinstance(work_dir, str):
            self.work_dir: Optional[str] = osp.abspath(work_dir)
            mmcv.mkdir_or_exist(self.work_dir)
        elif work_dir is None:
            self.work_dir = None
        else:
            raise TypeError('"work_dir" must be a str or None')

        # get model name from the model class
        if hasattr(self.model, 'module'):
            self._model_name = self.model.module.__class__.__name__
        else:
            self._model_name = self.model.__class__.__name__

        self._rank, self._world_size = get_dist_info()
        self.timestamp = get_time_str()
        self.mode: Optional[str] = None
        self._hooks: List[Hook] = []
        self._epoch = 0
        self._iter = 0
        self._inner_iter = 0

        if max_epochs is not None and max_iters is not None:
            raise ValueError(
                'Only one of `max_epochs` or `max_iters` can be set.')

        self._max_epochs = max_epochs
        self._max_iters = max_iters
        # TODO: Redesign LogBuffer, it is not flexible and elegant enough
        self.log_buffer = LogBuffer()

    @property
    def model_name(self) -> str:
        """str: Name of the model, usually the module class name."""
        return self._model_name

    @property
    def rank(self) -> int:
        """int: Rank of current process. (distributed training)"""
        return self._rank

    @property
    def world_size(self) -> int:
        """int: Number of processes participating in the job.
        (distributed training)"""
        return self._world_size

    @property
    def hooks(self) -> List[Hook]:
        """list[:obj:`Hook`]: A list of registered hooks, sorted by
        priority."""
        return self._hooks

    @property
    def epoch(self) -> int:
        """int: Current epoch."""
        return self._epoch

    @property
    def iter(self) -> int:
        """int: Current iteration."""
        return self._iter

    @property
    def inner_iter(self) -> int:
        """int: Iteration in an epoch."""
        return self._inner_iter

    @property
    def max_epochs(self) -> Optional[int]:
        """int | None: Maximum training epochs."""
        return self._max_epochs

    @property
    def max_iters(self) -> Optional[int]:
        """int | None: Maximum training iterations."""
        return self._max_iters

    @abstractmethod
    def train(self):
        pass

    @abstractmethod
    def val(self):
        pass

    @abstractmethod
    def run(self, data_loaders: List[DataLoader],
            workflow: List[Tuple[str, int]], **kwargs) -> Any:
        pass

    @abstractmethod
    def save_checkpoint(self,
                        out_dir: str,
                        filename_tmpl: str,
                        save_optimizer: bool = True,
                        meta: Optional[Dict] = None,
                        create_symlink: bool = True) -> None:
        pass

    def current_lr(self) -> Union[List[float], Dict[str, List[float]]]:
        """Get current learning rates.

        Returns:
            list[float] | dict[str, list[float]]: Current learning rates of all
            param groups. If the runner has a dict of optimizers, this method
            will return a dict.

        Raises:
            RuntimeError: If the runner has no optimizer.
        """
        lr: Union[List[float], Dict[str, List[float]]]
        if isinstance(self.optimizer, torch.optim.Optimizer):
            lr = [group['lr'] for group in self.optimizer.param_groups]
        elif isinstance(self.optimizer, dict):
            lr = dict()
            for name, optim in self.optimizer.items():
                lr[name] = [group['lr'] for group in optim.param_groups]
        else:
            raise RuntimeError(
                'lr is not applicable because optimizer does not exist.')
        return lr

    def current_momentum(self) -> Union[List[float], Dict[str, List[float]]]:
        """Get current momentums.

        For each param group the value of ``'momentum'`` is used; if absent,
        the first element of ``'betas'`` (Adam-style optimizers); otherwise 0.

        Returns:
            list[float] | dict[str, list[float]]: Current momentums of all
            param groups. If the runner has a dict of optimizers, this method
            will return a dict.

        Raises:
            RuntimeError: If the runner has no (valid) optimizer.
        """

        def _get_momentum(optimizer):
            momentums = []
            for group in optimizer.param_groups:
                if 'momentum' in group.keys():
                    momentums.append(group['momentum'])
                elif 'betas' in group.keys():
                    momentums.append(group['betas'][0])
                else:
                    momentums.append(0)
            return momentums

        momentums: Union[List[float], Dict[str, List[float]]]
        if isinstance(self.optimizer, torch.optim.Optimizer):
            momentums = _get_momentum(self.optimizer)
        elif isinstance(self.optimizer, dict):
            momentums = {
                name: _get_momentum(optim)
                for name, optim in self.optimizer.items()
            }
        else:
            # Covers `None` as well as any unexpected type assigned after
            # construction, which previously fell through to an
            # UnboundLocalError.
            raise RuntimeError(
                'momentum is not applicable because optimizer does not exist.')
        return momentums

    def register_hook(self,
                      hook: Hook,
                      priority: Union[int, str, Priority] = 'NORMAL') -> None:
        """Register a hook into the hook list.

        The hook will be inserted into a priority queue, with the specified
        priority (See :class:`Priority` for details of priorities).
        For hooks with the same priority, they will be triggered in the same
        order as they are registered.

        Args:
            hook (:obj:`Hook`): The hook to be registered.
            priority (int or str or :obj:`Priority`): Hook priority.
                Lower value means higher priority.

        Raises:
            ValueError: If ``hook`` already has a ``priority`` attribute.
        """
        assert isinstance(hook, Hook)
        if hasattr(hook, 'priority'):
            raise ValueError('"priority" is a reserved attribute for hooks')
        priority = get_priority(priority)
        hook.priority = priority  # type: ignore
        # Insert after the last hook whose priority is <= the new one, so
        # equal-priority hooks keep their registration order.
        for i in range(len(self._hooks) - 1, -1, -1):
            if priority >= self._hooks[i].priority:  # type: ignore
                self._hooks.insert(i + 1, hook)
                break
        else:
            self._hooks.insert(0, hook)

    def register_hook_from_cfg(self, hook_cfg: Dict) -> None:
        """Register a hook from its cfg.

        Args:
            hook_cfg (dict): Hook config. It should have at least keys 'type'
              and 'priority' indicating its type and priority. 'priority'
              defaults to 'NORMAL' when missing.

        Note:
            The specific hook class to register should not use 'type' and
            'priority' arguments during initialization.
        """
        # Copy so that popping 'priority' does not mutate the caller's cfg.
        hook_cfg = hook_cfg.copy()
        priority = hook_cfg.pop('priority', 'NORMAL')
        hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
        self.register_hook(hook, priority=priority)

    def call_hook(self, fn_name: str) -> None:
        """Call all hooks.

        Args:
            fn_name (str): The function name in each hook to be called, such as
                "before_train_epoch".
        """
        for hook in self._hooks:
            getattr(hook, fn_name)(self)

    def get_hook_info(self) -> str:
        """Summarize registered hooks grouped by the stage they trigger in.

        Returns:
            str: One section per stage in ``Hook.stages`` that has at least
            one hook, listing each hook's priority (name when it maps to a
            :class:`Priority` member, raw value otherwise) and class name.
        """
        # Get hooks info in each stage
        stage_hook_map: Dict[str, list] = {stage: [] for stage in Hook.stages}
        for hook in self.hooks:
            try:
                priority = Priority(hook.priority).name  # type: ignore
            except ValueError:
                priority = hook.priority  # type: ignore
            classname = hook.__class__.__name__
            hook_info = f'({priority:<12}) {classname:<35}'
            for trigger_stage in hook.get_triggered_stages():
                stage_hook_map[trigger_stage].append(hook_info)

        stage_hook_infos = []
        for stage in Hook.stages:
            hook_infos = stage_hook_map[stage]
            if len(hook_infos) > 0:
                info = f'{stage}:\n'
                info += '\n'.join(hook_infos)
                info += '\n -------------------- '
                stage_hook_infos.append(info)
        return '\n'.join(stage_hook_infos)

    def load_checkpoint(
        self,
        filename: str,
        map_location: Union[str, Callable] = 'cpu',
        strict: bool = False,
        revise_keys: Optional[List] = None,
    ) -> Union[Dict, OrderedDict]:
        """Load a checkpoint into ``self.model``.

        Args:
            filename (str): Checkpoint filename/URI.
            map_location (str | callable): Forwarded to the module-level
                ``load_checkpoint``. Defaults to 'cpu'.
            strict (bool): Forwarded to the module-level ``load_checkpoint``.
                Defaults to False.
            revise_keys (list, optional): ``(pattern, replacement)`` regex
                pairs used to rewrite state-dict keys. Defaults to
                ``[(r'^module\\.', '')]``, i.e. strip a leading ``module.``
                prefix.

        Returns:
            dict | OrderedDict: The loaded checkpoint.
        """
        if revise_keys is None:
            # The dot is escaped so only a literal ``module.`` prefix is
            # stripped; the previous pattern ``^module.`` also mangled keys
            # such as ``module_list.0.weight``.
            revise_keys = [(r'^module\.', '')]
        return load_checkpoint(
            self.model,
            filename,
            map_location,
            strict,
            self.logger,
            revise_keys=revise_keys)

    @no_type_check
    def resume(self,
               checkpoint: str,
               resume_optimizer: bool = True,
               map_location: Union[str, Callable] = 'default') -> None:
        """Resume training state from a checkpoint.

        Restores the epoch/iteration counters, the runner meta and
        (optionally) the optimizer state(s) stored in the checkpoint.

        Args:
            checkpoint (str): Checkpoint filename/URI passed to
                :meth:`load_checkpoint`.
            resume_optimizer (bool): Whether to also load the optimizer
                state(s) stored under the ``'optimizer'`` key.
                Defaults to True.
            map_location (str | callable): Passed to :meth:`load_checkpoint`.
                With ``'default'``, tensors are mapped onto the current CUDA
                device when CUDA is available; otherwise
                :meth:`load_checkpoint`'s own default (``'cpu'``) is used.

        Raises:
            TypeError: If optimizer state is to be resumed but
                ``self.optimizer`` is neither an Optimizer nor a dict.
        """
        if map_location == 'default':
            if torch.cuda.is_available():
                device_id = torch.cuda.current_device()
                checkpoint = self.load_checkpoint(
                    checkpoint,
                    map_location=lambda storage, loc: storage.cuda(device_id))
            else:
                checkpoint = self.load_checkpoint(checkpoint)
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        if self.meta is None:
            self.meta = {}
        self.meta.setdefault('hook_msgs', {})
        # load `last_ckpt`, `best_score`, `best_ckpt`, etc. for hook messages
        self.meta['hook_msgs'].update(checkpoint['meta'].get('hook_msgs', {}))

        # Re-calculate the number of iterations when resuming
        # models with different number of GPUs
        if 'config' in checkpoint['meta']:
            config = mmcv.Config.fromstring(
                checkpoint['meta']['config'], file_format='.py')
            previous_gpu_ids = config.get('gpu_ids', None)
            if previous_gpu_ids and len(
                    previous_gpu_ids) != self.world_size:
                self._iter = int(self._iter * len(previous_gpu_ids) /
                                 self.world_size)
                self.logger.info('the iteration number is changed due to '
                                 'change of GPU number')

        # resume meta information meta
        # NOTE(review): this replaces `self.meta` wholesale, so the
        # `hook_msgs` merge above only survives through the copy already in
        # the checkpoint's meta. Kept as-is to preserve existing behaviour.
        self.meta = checkpoint['meta']

        if 'optimizer' in checkpoint and resume_optimizer:
            if isinstance(self.optimizer, Optimizer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(self.optimizer, dict):
                for k in self.optimizer.keys():
                    self.optimizer[k].load_state_dict(
                        checkpoint['optimizer'][k])
            else:
                raise TypeError(
                    'Optimizer should be dict or torch.optim.Optimizer '
                    f'but got {type(self.optimizer)}')

        self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)

    def register_lr_hook(self, lr_config: Union[Dict, Hook, None]) -> None:
        """Register a learning-rate updater hook with ``VERY_HIGH`` priority.

        Args:
            lr_config (dict | :obj:`Hook` | None): Either a built hook, or a
                config dict with a ``'policy'`` key that is mapped to the
                ``<Policy>LrUpdaterHook`` type. The caller's dict is not
                modified. Nothing is registered when None.
        """
        if lr_config is None:
            return
        elif isinstance(lr_config, dict):
            assert 'policy' in lr_config
            # Work on a shallow copy so the caller's config keeps its
            # 'policy' key (and can be registered again).
            lr_config = dict(lr_config)
            policy_type = lr_config.pop('policy')
            # If the type of policy is all in lower case, e.g., 'cyclic',
            # then its first letter will be capitalized, e.g., to be 'Cyclic'.
            # This is for the convenient usage of Lr updater.
            # Since this is not applicable for `
            # CosineAnnealingLrUpdater`,
            # the string will not be changed if it contains capital letters.
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            hook_type = policy_type + 'LrUpdaterHook'
            lr_config['type'] = hook_type
            hook = mmcv.build_from_cfg(lr_config, HOOKS)
        else:
            hook = lr_config
        self.register_hook(hook, priority='VERY_HIGH')

    def register_momentum_hook(
            self, momentum_config: Union[Dict, Hook, None]) -> None:
        """Register a momentum updater hook with ``HIGH`` priority.

        Args:
            momentum_config (dict | :obj:`Hook` | None): Either a built hook,
                or a config dict with a ``'policy'`` key that is mapped to the
                ``<Policy>MomentumUpdaterHook`` type. The caller's dict is not
                modified. Nothing is registered when None.
        """
        if momentum_config is None:
            return
        if isinstance(momentum_config, dict):
            assert 'policy' in momentum_config
            # Work on a shallow copy so the caller's config keeps its
            # 'policy' key (and can be registered again).
            momentum_config = dict(momentum_config)
            policy_type = momentum_config.pop('policy')
            # If the type of policy is all in lower case, e.g., 'cyclic',
            # then its first letter will be capitalized, e.g., to be 'Cyclic'.
            # This is for the convenient usage of momentum updater.
            # Since this is not applicable for
            # `CosineAnnealingMomentumUpdater`,
            # the string will not be changed if it contains capital letters.
            if policy_type == policy_type.lower():
                policy_type = policy_type.title()
            hook_type = policy_type + 'MomentumUpdaterHook'
            momentum_config['type'] = hook_type
            hook = mmcv.build_from_cfg(momentum_config, HOOKS)
        else:
            hook = momentum_config
        self.register_hook(hook, priority='HIGH')

    def register_optimizer_hook(
            self, optimizer_config: Union[Dict, Hook, None]) -> None:
        """Register an optimizer hook with ``ABOVE_NORMAL`` priority.

        A config dict without ``'type'`` builds an ``OptimizerHook``. The
        caller's dict is not modified. Nothing is registered when None.
        """
        if optimizer_config is None:
            return
        if isinstance(optimizer_config, dict):
            optimizer_config = dict(optimizer_config)
            optimizer_config.setdefault('type', 'OptimizerHook')
            hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
        else:
            hook = optimizer_config
        self.register_hook(hook, priority='ABOVE_NORMAL')

    def register_checkpoint_hook(
            self, checkpoint_config: Union[Dict, Hook, None]) -> None:
        """Register a checkpoint hook with ``NORMAL`` priority.

        A config dict without ``'type'`` builds a ``CheckpointHook``. The
        caller's dict is not modified. Nothing is registered when None.
        """
        if checkpoint_config is None:
            return
        if isinstance(checkpoint_config, dict):
            checkpoint_config = dict(checkpoint_config)
            checkpoint_config.setdefault('type', 'CheckpointHook')
            hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
        else:
            hook = checkpoint_config
        self.register_hook(hook, priority='NORMAL')

    def register_logger_hooks(self, log_config: Optional[Dict]) -> None:
        """Register every hook in ``log_config['hooks']`` with ``VERY_LOW``
        priority, passing ``log_config['interval']`` as the default
        ``interval`` argument. Nothing is registered when None."""
        if log_config is None:
            return
        log_interval = log_config['interval']
        for info in log_config['hooks']:
            logger_hook = mmcv.build_from_cfg(
                info, HOOKS, default_args=dict(interval=log_interval))
            self.register_hook(logger_hook, priority='VERY_LOW')

    def register_timer_hook(
        self,
        timer_config: Union[Dict, Hook, None],
    ) -> None:
        """Register a timer hook with ``LOW`` priority.

        A config dict is deep-copied before building, so the caller's dict
        is not modified. Nothing is registered when None.
        """
        if timer_config is None:
            return
        if isinstance(timer_config, dict):
            timer_config_ = copy.deepcopy(timer_config)
            hook = mmcv.build_from_cfg(timer_config_, HOOKS)
        else:
            hook = timer_config
        self.register_hook(hook, priority='LOW')

    def register_custom_hooks(
            self, custom_config: Union[List, Dict, Hook, None]) -> None:
        """Register user-defined hooks.

        Args:
            custom_config (list | dict | :obj:`Hook` | None): A single hook
                or config, or a list of them. Config dicts go through
                :meth:`register_hook_from_cfg` (honouring their 'priority'
                key); built hooks are registered with ``NORMAL`` priority.
        """
        if custom_config is None:
            return

        if not isinstance(custom_config, list):
            custom_config = [custom_config]

        for item in custom_config:
            if isinstance(item, dict):
                self.register_hook_from_cfg(item)
            else:
                self.register_hook(item, priority='NORMAL')

    def register_profiler_hook(
        self,
        profiler_config: Union[Dict, Hook, None],
    ) -> None:
        """Register a profiler hook with the default (``NORMAL``) priority.

        A config dict without ``'type'`` builds a ``ProfilerHook``. The
        caller's dict is not modified. Nothing is registered when None.
        """
        if profiler_config is None:
            return
        if isinstance(profiler_config, dict):
            profiler_config = dict(profiler_config)
            profiler_config.setdefault('type', 'ProfilerHook')
            hook = mmcv.build_from_cfg(profiler_config, HOOKS)
        else:
            hook = profiler_config
        self.register_hook(hook)

    def register_training_hooks(
            self,
            lr_config: Union[Dict, Hook, None],
            optimizer_config: Union[Dict, Hook, None] = None,
            checkpoint_config: Union[Dict, Hook, None] = None,
            log_config: Optional[Dict] = None,
            momentum_config: Union[Dict, Hook, None] = None,
            timer_config: Union[Dict, Hook] = dict(type='IterTimerHook'),
            custom_hooks_config: Union[List, Dict, Hook, None] = None) -> None:
        """Register default and custom hooks for training.

        Default and custom hooks include:

        +----------------------+-------------------------+
        | Hooks                | Priority                |
        +======================+=========================+
        | LrUpdaterHook        | VERY_HIGH (10)          |
        +----------------------+-------------------------+
        | MomentumUpdaterHook  | HIGH (30)               |
        +----------------------+-------------------------+
        | OptimizerHook        | ABOVE_NORMAL (40)       |
        +----------------------+-------------------------+
        | CheckpointHook       | NORMAL (50)             |
        +----------------------+-------------------------+
        | IterTimerHook        | LOW (70)                |
        +----------------------+-------------------------+
        | LoggerHook(s)        | VERY_LOW (90)           |
        +----------------------+-------------------------+
        | CustomHook(s)        | defaults to NORMAL (50) |
        +----------------------+-------------------------+

        If custom hooks have the same priority as default hooks, custom hooks
        will be triggered after default hooks.
        """
        # The dict default for `timer_config` is never mutated:
        # register_timer_hook deep-copies config dicts before building.
        self.register_lr_hook(lr_config)
        self.register_momentum_hook(momentum_config)
        self.register_optimizer_hook(optimizer_config)
        self.register_checkpoint_hook(checkpoint_config)
        self.register_timer_hook(timer_config)
        self.register_logger_hooks(log_config)
        self.register_custom_hooks(custom_hooks_config)