Unverified Commit 987cb583 authored by Jerry Jiarui XU's avatar Jerry Jiarui XU Committed by GitHub
Browse files

[Enhance] Refactor logger (#659)

* [Enhance] Refactor logger

* fixed test

* make commit optional

* remove debug info

* fixed test
parent dfa36dfe
...@@ -89,13 +89,6 @@ class LoggerHook(Hook): ...@@ -89,13 +89,6 @@ class LoggerHook(Hook):
current_iter = runner.iter + 1 current_iter = runner.iter + 1
return current_iter return current_iter
def get_step(self, runner):
"""Get the total training step/epoch."""
if self.get_mode(runner) == 'val' and self.by_epoch:
return self.get_epoch(runner)
else:
return self.get_iter(runner)
def get_lr_tags(self, runner): def get_lr_tags(self, runner):
tags = {} tags = {}
lrs = runner.current_lr() lrs = runner.current_lr()
......
...@@ -69,7 +69,7 @@ class MlflowLoggerHook(LoggerHook): ...@@ -69,7 +69,7 @@ class MlflowLoggerHook(LoggerHook):
def log(self, runner): def log(self, runner):
tags = self.get_loggable_tags(runner) tags = self.get_loggable_tags(runner)
if tags: if tags:
self.mlflow.log_metrics(tags, step=self.get_step(runner)) self.mlflow.log_metrics(tags, step=self.get_iter(runner))
@master_only @master_only
def after_run(self, runner): def after_run(self, runner):
......
...@@ -69,6 +69,13 @@ class PaviLoggerHook(LoggerHook): ...@@ -69,6 +69,13 @@ class PaviLoggerHook(LoggerHook):
if self.add_graph: if self.add_graph:
self.writer.add_graph(runner.model) self.writer.add_graph(runner.model)
def get_step(self, runner):
"""Get the total training step/epoch."""
if self.get_mode(runner) == 'val' and self.by_epoch:
return self.get_epoch(runner)
else:
return self.get_iter(runner)
@master_only @master_only
def log(self, runner): def log(self, runner):
tags = self.get_loggable_tags(runner, add_mode=False) tags = self.get_loggable_tags(runner, add_mode=False)
......
...@@ -46,9 +46,9 @@ class TensorboardLoggerHook(LoggerHook): ...@@ -46,9 +46,9 @@ class TensorboardLoggerHook(LoggerHook):
tags = self.get_loggable_tags(runner, allow_text=True) tags = self.get_loggable_tags(runner, allow_text=True)
for tag, val in tags.items(): for tag, val in tags.items():
if isinstance(val, str): if isinstance(val, str):
self.writer.add_text(tag, val, self.get_step(runner)) self.writer.add_text(tag, val, self.get_iter(runner))
else: else:
self.writer.add_scalar(tag, val, self.get_step(runner)) self.writer.add_scalar(tag, val, self.get_iter(runner))
@master_only @master_only
def after_run(self, runner): def after_run(self, runner):
......
...@@ -12,11 +12,13 @@ class WandbLoggerHook(LoggerHook): ...@@ -12,11 +12,13 @@ class WandbLoggerHook(LoggerHook):
interval=10, interval=10,
ignore_last=True, ignore_last=True,
reset_flag=True, reset_flag=True,
commit=True,
by_epoch=True): by_epoch=True):
super(WandbLoggerHook, self).__init__(interval, ignore_last, super(WandbLoggerHook, self).__init__(interval, ignore_last,
reset_flag, by_epoch) reset_flag, by_epoch)
self.import_wandb() self.import_wandb()
self.init_kwargs = init_kwargs self.init_kwargs = init_kwargs
self.commit = commit
def import_wandb(self): def import_wandb(self):
try: try:
...@@ -39,7 +41,8 @@ class WandbLoggerHook(LoggerHook): ...@@ -39,7 +41,8 @@ class WandbLoggerHook(LoggerHook):
def log(self, runner): def log(self, runner):
tags = self.get_loggable_tags(runner) tags = self.get_loggable_tags(runner)
if tags: if tags:
self.wandb.log(tags, step=self.get_step(runner)) self.wandb.log(
tags, step=self.get_iter(runner), commit=self.commit)
@master_only @master_only
def after_run(self, runner): def after_run(self, runner):
......
...@@ -346,7 +346,7 @@ def test_mlflow_hook(log_model): ...@@ -346,7 +346,7 @@ def test_mlflow_hook(log_model):
{ {
'learning_rate': 0.02, 'learning_rate': 0.02,
'momentum': 0.95 'momentum': 0.95
}, step=1) }, step=6)
if log_model: if log_model:
hook.mlflow_pytorch.log_model.assert_called_with( hook.mlflow_pytorch.log_model.assert_called_with(
runner.model, 'models') runner.model, 'models')
...@@ -369,7 +369,8 @@ def test_wandb_hook(): ...@@ -369,7 +369,8 @@ def test_wandb_hook():
'learning_rate': 0.02, 'learning_rate': 0.02,
'momentum': 0.95 'momentum': 0.95
}, },
step=1) step=6,
commit=True)
hook.wandb.join.assert_called_with() hook.wandb.join.assert_called_with()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment