Unverified Commit 58a84833 authored by Haian Huang(深度眸)'s avatar Haian Huang(深度眸) Committed by GitHub
Browse files

Fix the iter error when the number of GPUs is different during resume (#844)

* Fix the iter error when the number of GPUs is different during resume

* Add fromstring and unit test

* Remove is_pretty_text

* Fix comment

* Add log info

* Add py format check

* Remove SyntaxError check
parent ba30d98a
...@@ -330,6 +330,20 @@ class BaseRunner(metaclass=ABCMeta): ...@@ -330,6 +330,20 @@ class BaseRunner(metaclass=ABCMeta):
self._epoch = checkpoint['meta']['epoch'] self._epoch = checkpoint['meta']['epoch']
self._iter = checkpoint['meta']['iter'] self._iter = checkpoint['meta']['iter']
# Re-calculate the number of iterations when resuming
# models with different number of GPUs
if 'config' in checkpoint['meta']:
config = mmcv.Config.fromstring(
checkpoint['meta']['config'], file_format='.py')
previous_gpu_ids = config.get('gpu_ids', None)
if previous_gpu_ids and len(previous_gpu_ids) > 0 and len(
previous_gpu_ids) != self.world_size:
self._iter = int(self._iter * len(previous_gpu_ids) /
self.world_size)
self.logger.info('the iteration number is changed due to '
'change of GPU number')
if 'optimizer' in checkpoint and resume_optimizer: if 'optimizer' in checkpoint and resume_optimizer:
if isinstance(self.optimizer, Optimizer): if isinstance(self.optimizer, Optimizer):
self.optimizer.load_state_dict(checkpoint['optimizer']) self.optimizer.load_state_dict(checkpoint['optimizer'])
......
...@@ -5,6 +5,7 @@ import platform ...@@ -5,6 +5,7 @@ import platform
import shutil import shutil
import sys import sys
import tempfile import tempfile
import warnings
from argparse import Action, ArgumentParser from argparse import Action, ArgumentParser
from collections import abc from collections import abc
from importlib import import_module from importlib import import_module
...@@ -253,6 +254,31 @@ class Config: ...@@ -253,6 +254,31 @@ class Config:
import_modules_from_strings(**cfg_dict['custom_imports']) import_modules_from_strings(**cfg_dict['custom_imports'])
return Config(cfg_dict, cfg_text=cfg_text, filename=filename) return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
@staticmethod
def fromstring(cfg_str, file_format):
"""Generate config from config str.
Args:
cfg_str (str): Config str.
file_format (str): Config file format corresponding to the
config str. Only py/yml/yaml/json type are supported now!
Returns:
obj:`Config`: Config obj.
"""
if file_format not in ['.py', '.json', '.yaml', '.yml']:
raise IOError('Only py/yml/yaml/json type are supported now!')
if file_format != '.py' and 'dict(' in cfg_str:
# check if users specify a wrong suffix for python
warnings.warn(
'Please check "file_format", the file format may be .py')
with tempfile.NamedTemporaryFile('w', suffix=file_format) as temp_file:
temp_file.write(cfg_str)
temp_file.flush()
cfg = Config.fromfile(temp_file.name)
return cfg
@staticmethod @staticmethod
def auto_argparser(description=None): def auto_argparser(description=None):
"""Generate argparser from config file automatically (experimental)""" """Generate argparser from config file automatically (experimental)"""
......
...@@ -161,6 +161,31 @@ def test_fromfile(): ...@@ -161,6 +161,31 @@ def test_fromfile():
Config.fromfile(osp.join(data_path, 'color.jpg')) Config.fromfile(osp.join(data_path, 'color.jpg'))
def test_fromstring():
for filename in ['a.py', 'a.b.py', 'b.json', 'c.yaml']:
cfg_file = osp.join(data_path, 'config', filename)
file_format = osp.splitext(filename)[-1]
in_cfg = Config.fromfile(cfg_file)
out_cfg = Config.fromstring(in_cfg.pretty_text, '.py')
assert in_cfg._cfg_dict == out_cfg._cfg_dict
cfg_str = open(cfg_file, 'r').read()
out_cfg = Config.fromstring(cfg_str, file_format)
assert in_cfg._cfg_dict == out_cfg._cfg_dict
# test pretty_text only supports py file format
cfg_file = osp.join(data_path, 'config', 'b.json')
in_cfg = Config.fromfile(cfg_file)
with pytest.raises(Exception):
Config.fromstring(in_cfg.pretty_text, '.json')
# test file format error
cfg_str = open(cfg_file, 'r').read()
with pytest.raises(Exception):
Config.fromstring(cfg_str, '.py')
def test_merge_from_base(): def test_merge_from_base():
cfg_file = osp.join(data_path, 'config/d.py') cfg_file = osp.join(data_path, 'config/d.py')
cfg = Config.fromfile(cfg_file) cfg = Config.fromfile(cfg_file)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment