multigridlr.py 1.39 KB
Newer Older
unknown's avatar
unknown committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner.hooks.hook import HOOKS
from mmcv.runner.hooks.lr_updater import LrUpdaterHook


@HOOKS.register_module()
class RelativeStepLrUpdaterHook(LrUpdaterHook):
    """RelativeStepLrUpdaterHook.
    Args:
        runner (:obj:`mmcv.Runner`): The runner instance used.
        steps (list[int]): The list of epochs at which decrease
            the learning rate.
        **kwargs (dict): Same as that of mmcv.
    """

    def __init__(self,
                 runner,
                 steps,
                 lrs,
                 warmup_epochs=34,
                 warmuplr_start=0.01,
                 **kwargs):
        super().__init__(**kwargs)
        assert len(steps) == (len(lrs))
        self.steps = steps
        self.lrs = lrs
        self.warmup_epochs = warmup_epochs
        self.warmuplr_start = warmuplr_start
        self.warmuplr_end = self.lrs[0]
        super().before_run(runner)

    def get_lr(self, runner, base_lr):
        """Similar to that of mmcv."""
        progress = runner.epoch if self.by_epoch else runner.iter
        if progress <= self.warmup_epochs:
            alpha = (self.warmuplr_end -
                     self.warmuplr_start) / self.warmup_epochs
            return progress * alpha + self.warmuplr_start
        for i in range(len(self.steps)):
            if progress < self.steps[i]:
                return self.lrs[i]