"...v2/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "a06df0d9229e49bd859e2ff0355a48b3bddd1e10"
Commit 507e67e4 authored by Kai Chen

make optimizer an optional argument for Runner

parent 7d872508
@@ -20,13 +20,16 @@ class Runner(object):

     def __init__(self,
                  model,
-                 optimizer,
                  batch_processor,
+                 optimizer=None,
                  work_dir=None,
                  log_level=logging.INFO):
         assert callable(batch_processor)
         self.model = model
-        self.optimizer = self.init_optimizer(optimizer)
+        if optimizer is not None:
+            self.optimizer = self.init_optimizer(optimizer)
+        else:
+            self.optimizer = None
         self.batch_processor = batch_processor
         # create work_dir
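With this hunk, a Runner can be built without an optimizer, e.g. for evaluation-only use. A minimal sketch of the new constructor signature; the import path, model, and batch_processor here are assumptions for illustration, not this repo's code:

import logging

import torch.nn as nn

from mmcv.runner import Runner  # assumed import path; adjust to this repo


def batch_processor(model, data_batch, train_mode):
    # Hypothetical processor; Runner only requires that it is callable
    # and (per the type checks below) returns a dict.
    loss = model(data_batch).sum()
    return dict(loss=loss, log_vars=dict(loss=loss.item()), num_samples=1)


model = nn.Linear(4, 2)

# Previously the optimizer was a required positional argument; now it
# defaults to None, so an evaluation-only runner needs none:
runner = Runner(model, batch_processor, work_dir='./work_dir',
                log_level=logging.INFO)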
@@ -152,6 +155,9 @@ class Runner(object):
         Returns:
             list: Current learning rate of all param groups.
         """
+        if self.optimizer is None:
+            raise RuntimeError(
+                'lr is not applicable because optimizer does not exist.')
         return [group['lr'] for group in self.optimizer.param_groups]

     def register_hook(self, hook, priority=50):
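The guard makes the failure mode explicit: without it, an optimizer-less runner would die with an AttributeError on NoneType. A sketch of the resulting behavior, assuming a runner built without an optimizer as in the example above and that current_lr is the method carrying this docstring:

# With no optimizer, current_lr() now fails loudly with a clear message.
try:
    runner.current_lr()
except RuntimeError as e:
    print(e)  # lr is not applicable because optimizer does not exist.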
@@ -234,8 +240,9 @@ class Runner(object):
         for i, data_batch in enumerate(data_loader):
             self._inner_iter = i
             self.call_hook('before_val_iter')
-            outputs = self.batch_processor(
-                self.model, data_batch, train_mode=False, **kwargs)
+            with torch.no_grad():
+                outputs = self.batch_processor(
+                    self.model, data_batch, train_mode=False, **kwargs)
             if not isinstance(outputs, dict):
                 raise TypeError('batch_processor() must return a dict')
             if 'log_vars' in outputs:
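Wrapping the validation forward pass in torch.no_grad() disables gradient tracking, so no autograd graph is built and memory use drops. A self-contained illustration of the effect (toy model, not this repo's code):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
x = torch.randn(8, 4)

y_train = model(x)
print(y_train.requires_grad)  # True: the autograd graph is recorded

with torch.no_grad():
    y_val = model(x)
print(y_val.requires_grad)  # False: no graph, lower memory footprint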
@@ -321,12 +328,12 @@ class Runner(object):
                 info, hooks, default_args=dict(interval=log_interval))
             self.register_hook(logger_hook, priority=60)

-    def register_default_hooks(self,
-                               lr_config,
-                               grad_clip_config=None,
-                               checkpoint_config=None,
-                               log_config=None):
-        """Register several default hooks.
+    def register_training_hooks(self,
+                                lr_config,
+                                grad_clip_config=None,
+                                checkpoint_config=None,
+                                log_config=None):
+        """Register default hooks for training.

         Default hooks include:
         - LrUpdaterHook
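Since this hunk renames the method rather than adding one, existing call sites of register_default_hooks must be updated; the parameter list is unchanged by this commit. Sketch, with a hypothetical lr_config whose real schema is defined by the LrUpdaterHook:

lr_config = dict(policy='step', step=[8, 11])  # hypothetical schema

# Before this commit:
#   runner.register_default_hooks(lr_config)
# After:
runner.register_training_hooks(lr_config)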