# lr_scheduler_test.py (646 bytes)
# Last change: "update", committed by dengjb.
import sys
import unittest

import torch
from torch import nn

sys.path.append('.')
from solver.lr_scheduler import WarmupMultiStepLR
from solver.build import make_optimizer
from config import cfg


class MyTestCase(unittest.TestCase):
    """Smoke test for ``WarmupMultiStepLR`` driving the optimizer built by ``make_optimizer``."""

    def test_something(self):
        """Step the warmup/multi-step schedule for 50 epochs and sanity-check the LRs.

        The scheduler is built with milestones ``[20, 40]`` and
        ``warmup_iters=10``. The learning rate in effect at each epoch is read
        from ``optimizer.param_groups`` (the value the optimizer actually uses)
        and printed so the schedule can still be inspected by eye.
        """
        net = nn.Linear(10, 10)
        optimizer = make_optimizer(cfg, net)
        warmup_iters = 10
        lr_scheduler = WarmupMultiStepLR(optimizer, [20, 40], warmup_iters=warmup_iters)

        num_epochs = 50
        iters_per_epoch = 3
        lrs = []
        for epoch in range(num_epochs):
            # LR in effect for this epoch, i.e. what the optimizer steps below use.
            lr = optimizer.param_groups[0]['lr']
            lrs.append(lr)
            print(epoch, lr)
            for _ in range(iters_per_epoch):
                # No backward pass has run, so the parameters have no gradients;
                # this only mimics the call order of a real training loop.
                optimizer.step()
            # Since PyTorch 1.1 the scheduler must be stepped *after* optimizer.step();
            # stepping it first skips the first value of the schedule and warns.
            lr_scheduler.step()

        self.assertEqual(len(lrs), num_epochs)
        # A non-positive LR would mean the schedule (or cfg's base LR) is broken.
        self.assertTrue(all(lr > 0 for lr in lrs), lrs)
        # Past the last milestone (40) the LR should be lower than right after
        # warmup (epoch 10, before the first milestone at 20).
        # NOTE(review): assumes the scheduler's default decay factor (gamma) is < 1.
        self.assertLess(lrs[-1], lrs[warmup_iters], lrs)


if __name__ == "__main__":
    # Allow running this file directly: discover and run the TestCase above.
    unittest.main()