lr_updater.py 1.83 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import HOOKS, LrUpdaterHook


@HOOKS.register_module()
class LinearLrUpdaterHook(LrUpdaterHook):
    """Linear learning rate scheduler for image generation.

    The learning rate starts at ``base_lr`` (the base learning rate handled by
    the mmcv parent hook). Given a target learning rate ``target_lr`` and a
    start point ``start`` (iteration / epoch), the learning rate is kept at
    ``base_lr`` before ``start``; after ``start`` it is moved linearly towards
    ``target_lr``, changing once every ``interval`` iterations / epochs.

    Args:
        target_lr (float): The target learning rate. Default: 0.
        start (int): The start point (iteration / epoch, specified by args
            'by_epoch' in its parent class in mmcv) to update learning rate.
            Default: 0.
        interval (int): The interval to update the learning rate. Must be
            positive. Default: 1.
        **kwargs: Forwarded unchanged to ``LrUpdaterHook.__init__``.

    Raises:
        ValueError: If ``interval`` is not positive.
    """

    def __init__(self, target_lr=0, start=0, interval=1, **kwargs):
        super().__init__(**kwargs)
        # ``get_lr`` floor-divides by ``interval``: 0 would raise
        # ZeroDivisionError and a negative value gives a meaningless schedule,
        # so reject both at construction time.
        if interval <= 0:
            raise ValueError(
                f'interval must be a positive integer, but got {interval}')
        self.target_lr = target_lr
        self.start = start
        self.interval = interval

    def get_lr(self, runner, base_lr):
        """Calculate the learning rate for the current iteration / epoch.

        Args:
            runner (object): The passed runner. ``epoch`` / ``max_epochs`` are
                read when ``self.by_epoch`` is true, otherwise ``iter`` /
                ``max_iters``.
            base_lr (float): Base learning rate.

        Returns:
            float: Current learning rate. ``base_lr`` is returned unchanged
            when not even one full ``interval`` fits between ``start`` and
            the end of training.
        """
        if self.by_epoch:
            progress = runner.epoch
            max_progress = runner.max_epochs
        else:
            progress = runner.iter
            max_progress = runner.max_iters
        assert max_progress >= self.start, (
            f'start ({self.start}) must not exceed the total training length '
            f'({max_progress})')

        # Number of ``interval``-sized update steps between ``start`` and the
        # end of training. It is 0 both when ``max_progress == start`` and
        # when the remaining span is shorter than one interval; in either
        # case the step count reached during training stays 0, so keep
        # ``base_lr`` instead of dividing by zero.
        total_steps = (max_progress - self.start) // self.interval
        if total_steps == 0:
            return base_lr

        # Before 'start', fix lr; after 'start', linearly update lr once
        # every ``interval`` iterations / epochs.
        finished_steps = max(0, progress - self.start) // self.interval
        factor = finished_steps / total_steps
        return base_lr + (self.target_lr - base_lr) * factor