Unverified commit 50af0099, authored by Zhiyuan Chen, committed by GitHub

merge the calling of train/val_step and batch_processor into run_iter (#553)

* merge train/val_step and batch_processor into run_iter

* turn the runner's self.train_mode attribute into a train_mode argument

* remove the abstract run_iter() method from base_runner
parent 8c70df33
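
For context: `run_iter()` prefers a registered `batch_processor`; otherwise it dispatches to `model.train_step()` or `model.val_step()` according to `train_mode`, and it requires the result to be a dict whose optional `log_vars`/`num_samples` entries feed the log buffer. Below is a minimal sketch of a model that satisfies this contract; `ToyModel` and its loss are hypothetical, not part of this commit:

```python
import torch.nn as nn


class ToyModel(nn.Module):
    """Hypothetical model obeying the dict contract run_iter() checks."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 1)

    def train_step(self, data_batch, optimizer, **kwargs):
        # Called by run_iter() when train_mode=True and no batch_processor
        # is registered; must return a dict. 'log_vars' and 'num_samples'
        # are forwarded to self.log_buffer.update(...).
        x, y = data_batch
        loss = ((self.fc(x) - y) ** 2).mean()
        return dict(loss=loss,
                    log_vars=dict(loss=loss.item()),
                    num_samples=x.size(0))

    def val_step(self, data_batch, optimizer, **kwargs):
        # Same contract; run_iter() takes this branch when train_mode=False
        # (the val loop wraps the call in torch.no_grad()).
        return self.train_step(data_batch, optimizer, **kwargs)
```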
@@ -19,6 +19,22 @@ class EpochBasedRunner(BaseRunner):
     This runner trains models epoch by epoch.
     """
 
+    def run_iter(self, data_batch, train_mode, **kwargs):
+        if self.batch_processor is not None:
+            outputs = self.batch_processor(
+                self.model, data_batch, train_mode=train_mode, **kwargs)
+        elif train_mode:
+            outputs = self.model.train_step(data_batch, self.optimizer,
+                                            **kwargs)
+        else:
+            outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
+        if not isinstance(outputs, dict):
+            raise TypeError('"batch_processor()" or "model.train_step()"'
+                            ' and "model.val_step()" must return a dict')
+        if 'log_vars' in outputs:
+            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
+        self.outputs = outputs
+
     def train(self, data_loader, **kwargs):
         self.model.train()
         self.mode = 'train'
@@ -29,19 +45,7 @@ class EpochBasedRunner(BaseRunner):
         for i, data_batch in enumerate(self.data_loader):
             self._inner_iter = i
             self.call_hook('before_train_iter')
-            if self.batch_processor is None:
-                outputs = self.model.train_step(data_batch, self.optimizer,
-                                                **kwargs)
-            else:
-                outputs = self.batch_processor(
-                    self.model, data_batch, train_mode=True, **kwargs)
-            if not isinstance(outputs, dict):
-                raise TypeError('"batch_processor()" or "model.train_step()"'
-                                ' must return a dict')
-            if 'log_vars' in outputs:
-                self.log_buffer.update(outputs['log_vars'],
-                                       outputs['num_samples'])
-            self.outputs = outputs
+            self.run_iter(data_batch, train_mode=True)
             self.call_hook('after_train_iter')
             self._iter += 1
 
@@ -58,19 +62,7 @@ class EpochBasedRunner(BaseRunner):
             self._inner_iter = i
             self.call_hook('before_val_iter')
             with torch.no_grad():
-                if self.batch_processor is None:
-                    outputs = self.model.val_step(data_batch, self.optimizer,
-                                                  **kwargs)
-                else:
-                    outputs = self.batch_processor(
-                        self.model, data_batch, train_mode=False, **kwargs)
-                if not isinstance(outputs, dict):
-                    raise TypeError('"batch_processor()" or "model.val_step()"'
-                                    ' must return a dict')
-                if 'log_vars' in outputs:
-                    self.log_buffer.update(outputs['log_vars'],
-                                           outputs['num_samples'])
-                self.outputs = outputs
+                self.run_iter(data_batch, train_mode=False)
             self.call_hook('after_val_iter')
 
         self.call_hook('after_val_epoch')
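
The legacy path still takes precedence: when a `batch_processor` callable is registered on the runner, `run_iter()` calls it first and now forwards the `train_mode` flag instead of the value previously hard-coded in each loop. A hypothetical processor matching that call (names and loss are illustrative only):

```python
def batch_processor(model, data_batch, train_mode, **kwargs):
    # run_iter() invokes this as batch_processor(self.model, data_batch,
    # train_mode=train_mode, **kwargs) whenever it is not None; train_mode
    # now distinguishes the train and val loops.
    x, y = data_batch
    loss = ((model(x) - y) ** 2).mean()
    # The same dict contract applies as for train_step()/val_step().
    return dict(loss=loss,
                log_vars=dict(loss=loss.item()),
                num_samples=x.size(0))
```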