"src/libtorio/ffmpeg/stream_writer/encoder.cpp" did not exist on "a8bb3973da165fa3287313b6e406928cc36fbc2e"
Unverified Commit 3d6032c6 authored by Qiaofei Li's avatar Qiaofei Li Committed by GitHub
Browse files

Add dataset classes name info to meta for saving ckpt (#776)

* add CLASSES to meta info

* Update checkpoint.py

* add unit test for CLASSES name

* clean up the tmp folder

* use tempfile to clean up temp folder
parent 905c9b43
......@@ -420,6 +420,10 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
if is_module_wrapper(model):
model = model.module
if hasattr(model, 'CLASSES') and model.CLASSES is not None:
# save class name to the meta
meta.update(CLASSES=model.CLASSES)
checkpoint = {
'meta': meta,
'state_dict': weights_to_cpu(get_state_dict(model))
......
......@@ -133,3 +133,35 @@ def test_load_pavimodel_dist():
with pytest.raises(FileNotFoundError):
# there is not such checkpoint for us to load
_ = load_pavimodel_dist('MyPaviFolder/checkpoint.pth')
def test_load_classes_name():
from mmcv.runner import load_checkpoint, save_checkpoint
import tempfile
import os
checkpoint_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth')
model = Model()
save_checkpoint(model, checkpoint_path)
checkpoint = load_checkpoint(model, checkpoint_path)
assert 'meta' in checkpoint and 'CLASSES' not in checkpoint['meta']
model.CLASSES = ('class1', 'class2')
save_checkpoint(model, checkpoint_path)
checkpoint = load_checkpoint(model, checkpoint_path)
assert 'meta' in checkpoint and 'CLASSES' in checkpoint['meta']
assert checkpoint['meta']['CLASSES'] == ('class1', 'class2')
model = Model()
wrapped_model = DDPWrapper(model)
save_checkpoint(wrapped_model, checkpoint_path)
checkpoint = load_checkpoint(wrapped_model, checkpoint_path)
assert 'meta' in checkpoint and 'CLASSES' not in checkpoint['meta']
wrapped_model.module.CLASSES = ('class1', 'class2')
save_checkpoint(wrapped_model, checkpoint_path)
checkpoint = load_checkpoint(wrapped_model, checkpoint_path)
assert 'meta' in checkpoint and 'CLASSES' in checkpoint['meta']
assert checkpoint['meta']['CLASSES'] == ('class1', 'class2')
# remove the temp file
os.remove(checkpoint_path)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment