test_lr_scheduler.py 1.79 KB
Newer Older
mashun1's avatar
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch

from basicsr.models.lr_scheduler import CosineAnnealingRestartLR

try:
    import matplotlib as mpl
    from matplotlib import pyplot as plt
    from matplotlib import ticker as mtick
except ImportError:
    # matplotlib is mandatory for this script: the very next statement uses
    # `mpl`. Print the installation hint, then re-raise so the failure is the
    # real ImportError instead of a confusing NameError on `mpl` further down.
    print('Please install matplotlib.')
    raise

# Select the Agg backend before any figure is created.
mpl.use('Agg')


def main():
    """Plot the learning rates produced by ``CosineAnnealingRestartLR``.

    Builds an Adam optimizer with two parameter groups (initial learning
    rates 4e-4 and 2e-4), attaches a ``CosineAnnealingRestartLR`` scheduler,
    steps both for 600000 iterations while recording the learning rate of
    each group after every scheduler step, and saves the resulting curves to
    ``test_lr_scheduler.png`` in the current working directory.
    """
    # Two dummy parameter groups; only their learning rates matter here.
    param_groups = [
        {'params': [torch.zeros(3, 64, 3, 3)], 'lr': 4e-4},
        {'params': [torch.zeros(3, 64, 3, 3)], 'lr': 2e-4},
    ]
    optimizer = torch.optim.Adam(param_groups, lr=2e-4, weight_decay=0, betas=(0.9, 0.99))

    # Scheduler configuration: one restart weight per period.
    period = [50000, 100000, 150000, 150000, 150000]
    restart_weights = [1, 1, 0.5, 1, 0.5]
    scheduler = CosineAnnealingRestartLR(
        optimizer,
        period,
        restart_weights=restart_weights,
        eta_min=1e-7,
    )

    # Record the learning rate of both groups after each scheduler step.
    total_iter = 600000
    first_group_lrs = []
    second_group_lrs = []
    for _ in range(total_iter):
        optimizer.step()
        scheduler.step()
        groups = optimizer.param_groups
        first_group_lrs.append(groups[0]['lr'])
        second_group_lrs.append(groups[1]['lr'])

    # draw figure
    mpl.style.use('default')

    fig = plt.figure(1)
    ax = plt.subplot(111)
    ax.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
    ax.set_title('Cosine Annealing Restart Learning Rate Scheme', fontsize=16, color='k')

    iterations = list(range(total_iter))
    ax.plot(iterations, first_group_lrs, linewidth=1.5, label='learning rate 1')
    ax.plot(iterations, second_group_lrs, linewidth=1.5, label='learning rate 2')
    ax.legend(loc='upper right', shadow=False)

    # Show y tick labels in scientific notation with one decimal place.
    ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))
    ax.set_ylabel('Learning Rate')
    ax.set_xlabel('Iteration')

    fig.savefig('test_lr_scheduler.png')

# Run the plotting routine only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()