Unverified Commit 642b48f1 authored by densechen's avatar densechen Committed by GitHub
Browse files

RunnerConstructor (#1296)



* runner constructor

* import runner at `__init__`

* fix yapf

* fix

* fix yapf

* better write

* add using example

* add common

* fix lint

* refactor format
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent 979a355d
......@@ -5,6 +5,7 @@ from .builder import RUNNERS, build_runner
from .checkpoint import (CheckpointLoader, _load_checkpoint,
_load_checkpoint_with_prefix, load_checkpoint,
load_state_dict, save_checkpoint, weights_to_cpu)
from .default_constructor import DefaultRunnerConstructor
from .dist_utils import (allreduce_grads, allreduce_params, get_dist_info,
init_dist, master_only)
from .epoch_based_runner import EpochBasedRunner, Runner
......@@ -42,5 +43,5 @@ __all__ = [
'allreduce_params', 'LossScaler', 'CheckpointLoader', 'BaseModule',
'_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook', 'Sequential',
'ModuleList', 'GradientCumulativeOptimizerHook',
'GradientCumulativeFp16OptimizerHook'
'GradientCumulativeFp16OptimizerHook', 'DefaultRunnerConstructor'
]
# Copyright (c) OpenMMLab. All rights reserved.
from ..utils import Registry, build_from_cfg
import copy
from ..utils import Registry
RUNNERS = Registry('runner')
RUNNER_BUILDERS = Registry('runner builder')
def build_runner_constructor(cfg):
return RUNNER_BUILDERS.build(cfg)
def build_runner(cfg, default_args=None):
return build_from_cfg(cfg, RUNNERS, default_args=default_args)
runner_cfg = copy.deepcopy(cfg)
constructor_type = runner_cfg.pop('constructor',
'DefaultRunnerConstructor')
runner_constructor = build_runner_constructor(
dict(
type=constructor_type,
runner_cfg=runner_cfg,
default_args=default_args))
runner = runner_constructor()
return runner
from .builder import RUNNER_BUILDERS, RUNNERS
@RUNNER_BUILDERS.register_module()
class DefaultRunnerConstructor:
"""Default constructor for runners.
Custom existing `Runner` like `EpocBasedRunner` though `RunnerConstructor`.
For example, We can inject some new properties and functions for `Runner`.
Example:
>>> from mmcv.runner import RUNNER_BUILDERS, build_runner
>>> # Define a new RunnerReconstructor
>>> @RUNNER_BUILDERS.register_module()
>>> class MyRunnerConstructor:
... def __init__(self, runner_cfg, default_args=None):
... if not isinstance(runner_cfg, dict):
... raise TypeError('runner_cfg should be a dict',
... f'but got {type(runner_cfg)}')
... self.runner_cfg = runner_cfg
... self.default_args = default_args
...
... def __call__(self):
... runner = RUNNERS.build(self.runner_cfg,
... default_args=self.default_args)
... # Add new properties for existing runner
... runner.my_name = 'my_runner'
... runner.my_function = lambda self: print(self.my_name)
... ...
>>> # build your runner
>>> runner_cfg = dict(type='EpochBasedRunner', max_epochs=40,
... constructor='MyRunnerConstructor')
>>> runner = build_runner(runner_cfg)
"""
def __init__(self, runner_cfg, default_args=None):
if not isinstance(runner_cfg, dict):
raise TypeError('runner_cfg should be a dict',
f'but got {type(runner_cfg)}')
self.runner_cfg = runner_cfg
self.default_args = default_args
def __call__(self):
return RUNNERS.build(self.runner_cfg, default_args=self.default_args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment