Commit 395efbf2 authored by Evgeny Nizhibitsky's avatar Evgeny Nizhibitsky Committed by Kai Chen
Browse files

Make latest checkpoint symlink relative (#75)

* Make latest symlink relative. Add tests

* Skip test_save_checkpoint in the absense of torch

* Fix pep8 in test_runner

* Import runner only if torch import has succeeded
parent 123720b6
......@@ -243,11 +243,13 @@ class Runner(object):
else:
meta.update(epoch=self.epoch + 1, iter=self.iter)
filename = osp.join(out_dir, filename_tmpl.format(self.epoch + 1))
linkname = osp.join(out_dir, 'latest.pth')
filename = filename_tmpl.format(self.epoch + 1)
filepath = osp.join(out_dir, filename)
linkpath = osp.join(out_dir, 'latest.pth')
optimizer = self.optimizer if save_optimizer else None
save_checkpoint(self.model, filename, optimizer=optimizer, meta=meta)
mmcv.symlink(filename, linkname)
save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
# use relative symlink
mmcv.symlink(filename, linkpath)
def train(self, data_loader, **kwargs):
self.model.train()
......
import os.path as osp
import tempfile
import warnings
def test_save_checkpoint():
try:
import torch
import torch.nn as nn
except ImportError:
warnings.warn('Skipping test_save_checkpoint in the absense of torch')
return
import mmcv.runner
model = nn.Linear(1, 1)
runner = mmcv.runner.Runner(
model=model,
batch_processor=lambda x: x
)
with tempfile.TemporaryDirectory() as root:
runner.save_checkpoint(root)
latest_path = osp.join(root, 'latest.pth')
epoch1_path = osp.join(root, 'epoch_1.pth')
assert osp.exists(latest_path)
assert osp.exists(epoch1_path)
assert osp.realpath(latest_path) == epoch1_path
torch.load(latest_path)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment