Unverified Commit e417035f authored by Ivan Shcheklein's avatar Ivan Shcheklein Committed by GitHub
Browse files

[Enhancement] Add ability to pass logger instance to frameworks (#2317)



* Add ability to pass logger instance to frameworks

* refine docstring

* Update mmcv/runner/hooks/logger/dvclive.py
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent dd5b415d
...@@ -22,7 +22,10 @@ class DvcliveLoggerHook(LoggerHook): ...@@ -22,7 +22,10 @@ class DvcliveLoggerHook(LoggerHook):
reset_flag (bool): Whether to clear the output buffer after logging. reset_flag (bool): Whether to clear the output buffer after logging.
Default: False. Default: False.
by_epoch (bool): Whether EpochBasedRunner is used. Default: True. by_epoch (bool): Whether EpochBasedRunner is used. Default: True.
kwargs: Arguments for instantiating `Live`_. dvclive (Live, optional): An instance of the `Live`_ logger to use
instead of initializing a new one internally. Defaults to None.
kwargs: Arguments for instantiating `Live`_ (ignored if `dvclive` is
provided).
.. _dvclive: .. _dvclive:
https://dvc.org/doc/dvclive https://dvc.org/doc/dvclive
...@@ -37,18 +40,19 @@ class DvcliveLoggerHook(LoggerHook): ...@@ -37,18 +40,19 @@ class DvcliveLoggerHook(LoggerHook):
ignore_last: bool = True, ignore_last: bool = True,
reset_flag: bool = False, reset_flag: bool = False,
by_epoch: bool = True, by_epoch: bool = True,
dvclive=None,
**kwargs): **kwargs):
super().__init__(interval, ignore_last, reset_flag, by_epoch) super().__init__(interval, ignore_last, reset_flag, by_epoch)
self.model_file = model_file self.model_file = model_file
self.import_dvclive(**kwargs) self._import_dvclive(dvclive, **kwargs)
def import_dvclive(self, **kwargs) -> None: def _import_dvclive(self, dvclive=None, **kwargs) -> None:
try: try:
from dvclive import Live from dvclive import Live
except ImportError: except ImportError:
raise ImportError( raise ImportError(
'Please run "pip install dvclive" to install dvclive') 'Please run "pip install dvclive" to install dvclive')
self.dvclive = Live(**kwargs) self.dvclive = dvclive if dvclive is not None else Live(**kwargs)
@master_only @master_only
def log(self, runner) -> None: def log(self, runner) -> None:
......
...@@ -1665,7 +1665,6 @@ def test_dvclive_hook_model_file(tmp_path): ...@@ -1665,7 +1665,6 @@ def test_dvclive_hook_model_file(tmp_path):
hook = DvcliveLoggerHook(model_file=osp.join(runner.work_dir, 'model.pth')) hook = DvcliveLoggerHook(model_file=osp.join(runner.work_dir, 'model.pth'))
runner.register_hook(hook) runner.register_hook(hook)
loader = torch.utils.data.DataLoader(torch.ones((5, 2)))
loader = DataLoader(torch.ones((5, 2))) loader = DataLoader(torch.ones((5, 2)))
runner.run([loader, loader], [('train', 1), ('val', 1)]) runner.run([loader, loader], [('train', 1), ('val', 1)])
...@@ -1675,6 +1674,16 @@ def test_dvclive_hook_model_file(tmp_path): ...@@ -1675,6 +1674,16 @@ def test_dvclive_hook_model_file(tmp_path):
shutil.rmtree(runner.work_dir) shutil.rmtree(runner.work_dir)
def test_dvclive_hook_pass_logger(tmp_path):
sys.modules['dvclive'] = MagicMock()
from dvclive import Live
logger = Live()
sys.modules['dvclive'] = MagicMock()
assert DvcliveLoggerHook().dvclive is not logger
assert DvcliveLoggerHook(dvclive=logger).dvclive is logger
def test_clearml_hook(): def test_clearml_hook():
sys.modules['clearml'] = MagicMock() sys.modules['clearml'] = MagicMock()
runner = _build_demo_runner() runner = _build_demo_runner()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment