Commit 507e67e4 authored by Kai Chen

make optimizer an optional argument for Runner

parent 7d872508
@@ -20,13 +20,16 @@ class Runner(object):
     def __init__(self,
                  model,
-                 optimizer,
                  batch_processor,
+                 optimizer=None,
                  work_dir=None,
                  log_level=logging.INFO):
         assert callable(batch_processor)
         self.model = model
-        self.optimizer = self.init_optimizer(optimizer)
+        if optimizer is not None:
+            self.optimizer = self.init_optimizer(optimizer)
+        else:
+            self.optimizer = None
         self.batch_processor = batch_processor
         # create work_dir
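
For illustration, a minimal sketch of how a Runner can now be constructed without an optimizer, e.g. for evaluation-only use. The import path, model, and batch_processor below are assumptions for this sketch, not part of the commit:

import torch
import torch.nn as nn

from mmcv.runner import Runner  # import path assumed, not shown in this diff

model = nn.Linear(4, 2)

def batch_processor(model, data_batch, train_mode, **kwargs):
    # Placeholder batch processor; the real one must return a dict
    # (see the validation hunk below), so this one does as well.
    inputs, targets = data_batch
    loss = nn.functional.mse_loss(model(inputs), targets)
    return dict(loss=loss, log_vars=dict(loss=loss.item()))

# Evaluation-only Runner: the optimizer argument can now simply be omitted.
eval_runner = Runner(model, batch_processor, work_dir='./work_dirs/demo')

# Training Runner: an optimizer is still passed explicitly, as before.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
train_runner = Runner(model, batch_processor, optimizer, work_dir='./work_dirs/demo')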
@@ -152,6 +155,9 @@ class Runner(object):
         Returns:
             list: Current learning rate of all param groups.
         """
+        if self.optimizer is None:
+            raise RuntimeError(
+                'lr is not applicable because optimizer does not exist.')
         return [group['lr'] for group in self.optimizer.param_groups]

     def register_hook(self, hook, priority=50):
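
A short behaviour sketch of the new guard in current_lr(); again the import path and the throwaway model and batch processor are placeholders:

import torch.nn as nn

from mmcv.runner import Runner  # import path assumed

runner = Runner(nn.Linear(4, 2), lambda model, data, train_mode: dict())  # no optimizer passed
try:
    runner.current_lr()
except RuntimeError as exc:
    print(exc)  # 'lr is not applicable because optimizer does not exist.'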
@@ -234,8 +240,9 @@ class Runner(object):
         for i, data_batch in enumerate(data_loader):
             self._inner_iter = i
             self.call_hook('before_val_iter')
-            outputs = self.batch_processor(
-                self.model, data_batch, train_mode=False, **kwargs)
+            with torch.no_grad():
+                outputs = self.batch_processor(
+                    self.model, data_batch, train_mode=False, **kwargs)
             if not isinstance(outputs, dict):
                 raise TypeError('batch_processor() must return a dict')
             if 'log_vars' in outputs:
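
For context, the wrapping added above follows the standard PyTorch idiom of disabling autograd bookkeeping during evaluation; a generic, self-contained sketch (not mmcv code):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
inputs = torch.randn(3, 4)

model.eval()
with torch.no_grad():
    outputs = model(inputs)  # no gradient graph is built inside this block

print(outputs.requires_grad)  # False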
@@ -321,12 +328,12 @@ class Runner(object):
                 info, hooks, default_args=dict(interval=log_interval))
             self.register_hook(logger_hook, priority=60)

-    def register_default_hooks(self,
-                               lr_config,
-                               grad_clip_config=None,
-                               checkpoint_config=None,
-                               log_config=None):
-        """Register several default hooks.
+    def register_training_hooks(self,
+                                lr_config,
+                                grad_clip_config=None,
+                                checkpoint_config=None,
+                                log_config=None):
+        """Register default hooks for training.

         Default hooks include:

         - LrUpdaterHook
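
A hedged sketch of calling the renamed method; only the keyword names come from this diff, and every value below is a placeholder that would need real hook settings:

import torch
import torch.nn as nn

from mmcv.runner import Runner  # import path assumed

model = nn.Linear(4, 2)
runner = Runner(model, lambda m, d, train_mode: dict(),
                torch.optim.SGD(model.parameters(), lr=0.01))

# Only the keyword names come from this diff; the values are placeholders.
runner.register_training_hooks(
    lr_config=dict(),          # LrUpdaterHook settings (placeholder)
    grad_clip_config=None,     # optional gradient clipping settings
    checkpoint_config=None,    # optional checkpointing settings
    log_config=None)           # optional logging settings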