# test_runner.py
import os.path as osp
import tempfile
import warnings


def test_save_checkpoint():
    """Check that ``Runner.save_checkpoint`` writes a loadable checkpoint.

    After one call to ``save_checkpoint`` on a fresh runner, the checkpoint
    directory must contain ``epoch_1.pth`` plus a ``latest.pth`` that resolves
    to the same file, and the checkpoint must be loadable through
    ``latest.pth`` with ``torch.load``.

    The test is skipped (with a warning) when torch is not installed.
    """
    try:
        import torch
        import torch.nn as nn
    except ImportError:
        warnings.warn('Skipping test_save_checkpoint in the absence of torch')
        return

    import mmcv.runner

    model = nn.Linear(1, 1)
    # The batch processor is never invoked here; an identity stub is enough
    # to construct the runner.
    runner = mmcv.runner.Runner(model=model, batch_processor=lambda x: x)

    with tempfile.TemporaryDirectory() as root:
        runner.save_checkpoint(root)

        latest_path = osp.join(root, 'latest.pth')
        # A freshly built runner is expected to name its first checkpoint
        # epoch_1 (presumably epoch index + 1; confirm against Runner).
        epoch1_path = osp.join(root, 'epoch_1.pth')

        assert osp.exists(latest_path)
        assert osp.exists(epoch1_path)
        # Resolve BOTH sides: the temp directory itself may sit under a
        # symlink (e.g. macOS /var -> /private/var). Resolving only one side
        # would then make the comparison fail spuriously.
        assert osp.realpath(latest_path) == osp.realpath(epoch1_path)

        # The checkpoint must be readable through the 'latest' alias.
        torch.load(latest_path)