Unverified Commit 55fadb4c authored by Jiamin's avatar Jiamin Committed by GitHub
Browse files

Add runner.meta to checkpoint in save_checkpoint() (#438)

* fix: error when runner.meta is None

* tests: add unittest for epoch-based save_checkpoint
parent 215a3244
......@@ -147,8 +147,13 @@ class EpochBasedRunner(BaseRunner):
"""
if meta is None:
meta = dict(epoch=self.epoch + 1, iter=self.iter)
else:
elif isinstance(meta, dict):
meta.update(epoch=self.epoch + 1, iter=self.iter)
else:
raise TypeError(
f'meta should be a dict or None, but got {type(meta)}')
if self.meta is not None:
meta.update(self.meta)
filename = filename_tmpl.format(self.epoch + 1)
filepath = osp.join(out_dir, filename)
......
......@@ -183,7 +183,8 @@ class IterBasedRunner(BaseRunner):
else:
raise TypeError(
f'meta should be a dict or None, but got {type(meta)}')
meta.update(self.meta)
if self.meta is not None:
meta.update(self.meta)
filename = filename_tmpl.format(self.iter + 1)
filepath = osp.join(out_dir, filename)
......
......@@ -131,6 +131,10 @@ def test_save_checkpoint():
model = Model()
runner = EpochBasedRunner(model=model, logger=logging.getLogger())
with pytest.raises(TypeError):
# meta should be None or dict
runner.save_checkpoint('.', meta=list())
with tempfile.TemporaryDirectory() as root:
runner.save_checkpoint(root)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment