Unverified Commit c8435966 authored by su's avatar su Committed by GitHub
Browse files

Reordered the hooks and use attributes rather than args. (#544)

* Reordered the hooks and use attributes rather than args.

Formated.

* Reordering may cause conflict, assign the value first than update such as max_iter

Rewind back the order.

* Rewind back to just use attributes, the update of max_iter and stuff will be done in new hooks.

Minor format.
parent a59a35bc
...@@ -99,10 +99,11 @@ class EpochBasedRunner(BaseRunner): ...@@ -99,10 +99,11 @@ class EpochBasedRunner(BaseRunner):
work_dir = self.work_dir if self.work_dir is not None else 'NONE' work_dir = self.work_dir if self.work_dir is not None else 'NONE'
self.logger.info('Start running, host: %s, work_dir: %s', self.logger.info('Start running, host: %s, work_dir: %s',
get_host_info(), work_dir) get_host_info(), work_dir)
self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs) self.logger.info('workflow: %s, max: %d epochs', workflow,
self._max_epochs)
self.call_hook('before_run') self.call_hook('before_run')
while self.epoch < max_epochs: while self.epoch < self._max_epochs:
for i, flow in enumerate(workflow): for i, flow in enumerate(workflow):
mode, epochs = flow mode, epochs = flow
if isinstance(mode, str): # self.train() if isinstance(mode, str): # self.train()
...@@ -117,7 +118,7 @@ class EpochBasedRunner(BaseRunner): ...@@ -117,7 +118,7 @@ class EpochBasedRunner(BaseRunner):
type(mode))) type(mode)))
for _ in range(epochs): for _ in range(epochs):
if mode == 'train' and self.epoch >= max_epochs: if mode == 'train' and self.epoch >= self._max_epochs:
break break
epoch_runner(data_loaders[i], **kwargs) epoch_runner(data_loaders[i], **kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment