Unverified Commit 58a84833 authored by Haian Huang(深度眸)'s avatar Haian Huang(深度眸) Committed by GitHub
Browse files

Fix the iter error when the number of GPUs is different during resume (#844)

* Fix the iter error when the number of GPUs is different during resume

* Add fromstring and unit test

* Remove is_pretty_text

* Fix comment

* Add log info

* Add py format check

* Remove SyntaxError check
parent ba30d98a
......@@ -330,6 +330,20 @@ class BaseRunner(metaclass=ABCMeta):
self._epoch = checkpoint['meta']['epoch']
self._iter = checkpoint['meta']['iter']
# Re-calculate the number of iterations when resuming
# models with different number of GPUs
if 'config' in checkpoint['meta']:
config = mmcv.Config.fromstring(
checkpoint['meta']['config'], file_format='.py')
previous_gpu_ids = config.get('gpu_ids', None)
if previous_gpu_ids and len(previous_gpu_ids) > 0 and len(
previous_gpu_ids) != self.world_size:
self._iter = int(self._iter * len(previous_gpu_ids) /
self.world_size)
self.logger.info('the iteration number is changed due to '
'change of GPU number')
if 'optimizer' in checkpoint and resume_optimizer:
if isinstance(self.optimizer, Optimizer):
self.optimizer.load_state_dict(checkpoint['optimizer'])
......
......@@ -5,6 +5,7 @@ import platform
import shutil
import sys
import tempfile
import warnings
from argparse import Action, ArgumentParser
from collections import abc
from importlib import import_module
......@@ -253,6 +254,31 @@ class Config:
import_modules_from_strings(**cfg_dict['custom_imports'])
return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
@staticmethod
def fromstring(cfg_str, file_format):
"""Generate config from config str.
Args:
cfg_str (str): Config str.
file_format (str): Config file format corresponding to the
config str. Only py/yml/yaml/json type are supported now!
Returns:
obj:`Config`: Config obj.
"""
if file_format not in ['.py', '.json', '.yaml', '.yml']:
raise IOError('Only py/yml/yaml/json type are supported now!')
if file_format != '.py' and 'dict(' in cfg_str:
# check if users specify a wrong suffix for python
warnings.warn(
'Please check "file_format", the file format may be .py')
with tempfile.NamedTemporaryFile('w', suffix=file_format) as temp_file:
temp_file.write(cfg_str)
temp_file.flush()
cfg = Config.fromfile(temp_file.name)
return cfg
@staticmethod
def auto_argparser(description=None):
"""Generate argparser from config file automatically (experimental)"""
......
......@@ -161,6 +161,31 @@ def test_fromfile():
Config.fromfile(osp.join(data_path, 'color.jpg'))
def test_fromstring():
for filename in ['a.py', 'a.b.py', 'b.json', 'c.yaml']:
cfg_file = osp.join(data_path, 'config', filename)
file_format = osp.splitext(filename)[-1]
in_cfg = Config.fromfile(cfg_file)
out_cfg = Config.fromstring(in_cfg.pretty_text, '.py')
assert in_cfg._cfg_dict == out_cfg._cfg_dict
cfg_str = open(cfg_file, 'r').read()
out_cfg = Config.fromstring(cfg_str, file_format)
assert in_cfg._cfg_dict == out_cfg._cfg_dict
# test pretty_text only supports py file format
cfg_file = osp.join(data_path, 'config', 'b.json')
in_cfg = Config.fromfile(cfg_file)
with pytest.raises(Exception):
Config.fromstring(in_cfg.pretty_text, '.json')
# test file format error
cfg_str = open(cfg_file, 'r').read()
with pytest.raises(Exception):
Config.fromstring(cfg_str, '.py')
def test_merge_from_base():
cfg_file = osp.join(data_path, 'config/d.py')
cfg = Config.fromfile(cfg_file)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment