# Copyright (c) Open-MMLab. All rights reserved.
import logging
import os.path as osp
import tempfile
import warnings

def test_save_checkpoint():
    """Check that ``Runner.save_checkpoint`` writes a loadable checkpoint.

    Saving into an empty temporary directory is expected to produce both
    ``epoch_1.pth`` and ``latest.pth``, and the two paths must resolve to
    the same file on disk. The saved file must also be readable with
    ``torch.load``. When PyTorch is not installed, the test emits a warning
    and returns without checking anything.
    """
    try:
        import torch
        from torch import nn
    except ImportError:
        warnings.warn('Skipping test_save_checkpoint in the absence of torch')
        return

    import mmcv.runner

    model = nn.Linear(1, 1)
    # Identity batch processor: this test never runs a training step, it only
    # exercises checkpoint saving.
    runner = mmcv.runner.Runner(
        model=model, batch_processor=lambda x: x, logger=logging.getLogger())

    with tempfile.TemporaryDirectory() as root:
        runner.save_checkpoint(root)

        latest_path = osp.join(root, 'latest.pth')
        epoch1_path = osp.join(root, 'epoch_1.pth')

        assert osp.exists(latest_path)
        assert osp.exists(epoch1_path)
        # Comparing resolved paths (rather than file contents) requires
        # latest.pth to point at the epoch checkpoint, not to be a copy of it.
        assert osp.realpath(latest_path) == osp.realpath(epoch1_path)

        # The checkpoint reached through latest.pth must be loadable by torch.
        torch.load(latest_path)