Unverified Commit 0dfc6ae6 authored by Rui Xu's avatar Rui Xu Committed by GitHub
Browse files

[bug] fix bug in resuming optim-dict when using epoch based runner (#610)

* fix bug in resuming optim-dict

* raise an error when meeting an unexpected type

* raise an error when meeting an unexpected type
parent 665fee24
...@@ -330,7 +330,16 @@ class BaseRunner(metaclass=ABCMeta): ...@@ -330,7 +330,16 @@ class BaseRunner(metaclass=ABCMeta):
self._epoch = checkpoint['meta']['epoch'] self._epoch = checkpoint['meta']['epoch']
self._iter = checkpoint['meta']['iter'] self._iter = checkpoint['meta']['iter']
if 'optimizer' in checkpoint and resume_optimizer: if 'optimizer' in checkpoint and resume_optimizer:
self.optimizer.load_state_dict(checkpoint['optimizer']) if isinstance(self.optimizer, Optimizer):
self.optimizer.load_state_dict(checkpoint['optimizer'])
elif isinstance(self.optimizer, dict):
for k in self.optimizer.keys():
self.optimizer[k].load_state_dict(
checkpoint['optimizer'][k])
else:
raise TypeError(
'Optimizer should be dict or torch.optim.Optimizer '
f'but got {type(self.optimizer)}')
self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter) self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
......
...@@ -165,6 +165,10 @@ class IterBasedRunner(BaseRunner): ...@@ -165,6 +165,10 @@ class IterBasedRunner(BaseRunner):
for k in self.optimizer.keys(): for k in self.optimizer.keys():
self.optimizer[k].load_state_dict( self.optimizer[k].load_state_dict(
checkpoint['optimizer'][k]) checkpoint['optimizer'][k])
else:
raise TypeError(
'Optimizer should be dict or torch.optim.Optimizer '
f'but got {type(self.optimizer)}')
self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}') self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment