Commit c2b397b1 authored by Sam Tsai's avatar Sam Tsai Committed by Facebook GitHub Bot
Browse files

make tensorboardx logging overridable.

Summary:
Add get_tbx_writer to runner class and call that in the do_train. Make tbx writer overridable.

(see D31289763 for a use case)

Reviewed By: zhanghang1989

Differential Revision: D31289763

fbshipit-source-id: 19ddbbe8df62f9da0640f595532cd8f1296e3be8
parent a9dce74e
...@@ -282,7 +282,7 @@ class Detectron2GoRunner(BaseRunner): ...@@ -282,7 +282,7 @@ class Detectron2GoRunner(BaseRunner):
if comm.is_main_process(): if comm.is_main_process():
if hasattr(model, "_visualize_model"): if hasattr(model, "_visualize_model"):
logger.info("Adding model visualization ...") logger.info("Adding model visualization ...")
tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR)) tbx_writer = self.get_tbx_writer(cfg)
model._visualize_model(tbx_writer) model._visualize_model(tbx_writer)
return model return model
...@@ -298,6 +298,10 @@ class Detectron2GoRunner(BaseRunner): ...@@ -298,6 +298,10 @@ class Detectron2GoRunner(BaseRunner):
def build_lr_scheduler(self, cfg, optimizer): def build_lr_scheduler(self, cfg, optimizer):
return d2_build_lr_scheduler(cfg, optimizer) return d2_build_lr_scheduler(cfg, optimizer)
@classmethod
def get_tbx_writer(cls, cfg):
return _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))
def _do_test(self, cfg, model, train_iter=None, model_tag="default"): def _do_test(self, cfg, model, train_iter=None, model_tag="default"):
"""train_iter: Current iteration of the model, None means final iteration""" """train_iter: Current iteration of the model, None means final iteration"""
assert len(cfg.DATASETS.TEST) assert len(cfg.DATASETS.TEST)
...@@ -337,7 +341,7 @@ class Detectron2GoRunner(BaseRunner): ...@@ -337,7 +341,7 @@ class Detectron2GoRunner(BaseRunner):
if not isinstance(evaluator, DatasetEvaluators): if not isinstance(evaluator, DatasetEvaluators):
evaluator = DatasetEvaluators([evaluator]) evaluator = DatasetEvaluators([evaluator])
if comm.is_main_process(): if comm.is_main_process():
tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR)) tbx_writer = self.get_tbx_writer(cfg)
logger.info("Adding visualization evaluator ...") logger.info("Adding visualization evaluator ...")
mapper = self.get_mapper(cfg, is_train=False) mapper = self.get_mapper(cfg, is_train=False)
vis_eval_type = self.get_visualization_evaluator() vis_eval_type = self.get_visualization_evaluator()
...@@ -388,11 +392,11 @@ class Detectron2GoRunner(BaseRunner): ...@@ -388,11 +392,11 @@ class Detectron2GoRunner(BaseRunner):
flattened_results = flatten_results_dict(results) flattened_results = flatten_results_dict(results)
for k, v in flattened_results.items(): for k, v in flattened_results.items():
tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR)) tbx_writer = self.get_tbx_writer(cfg)
tbx_writer._writer.add_scalar("eval_{}".format(k), v, train_iter) tbx_writer._writer.add_scalar("eval_{}".format(k), v, train_iter)
if comm.is_main_process(): if comm.is_main_process():
tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR)) tbx_writer = self.get_tbx_writer(cfg)
tbx_writer._writer.flush() tbx_writer._writer.flush()
return results return results
...@@ -455,7 +459,7 @@ class Detectron2GoRunner(BaseRunner): ...@@ -455,7 +459,7 @@ class Detectron2GoRunner(BaseRunner):
if not cfg.ABNORMAL_CHECKER.ENABLED: if not cfg.ABNORMAL_CHECKER.ENABLED:
return model return model
tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR)) tbx_writer = self.get_tbx_writer(cfg)
writers = abnormal_checker.get_writers(cfg, tbx_writer) writers = abnormal_checker.get_writers(cfg, tbx_writer)
checker = abnormal_checker.AbnormalLossChecker(start_iter, writers) checker = abnormal_checker.AbnormalLossChecker(start_iter, writers)
ret = abnormal_checker.AbnormalLossCheckerWrapper(model, checker) ret = abnormal_checker.AbnormalLossCheckerWrapper(model, checker)
...@@ -480,7 +484,7 @@ class Detectron2GoRunner(BaseRunner): ...@@ -480,7 +484,7 @@ class Detectron2GoRunner(BaseRunner):
] ]
if comm.is_main_process(): if comm.is_main_process():
tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR)) tbx_writer = self.get_tbx_writer(cfg)
writers = [ writers = [
CommonMetricPrinter(max_iter), CommonMetricPrinter(max_iter),
JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")), JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
...@@ -570,7 +574,7 @@ class Detectron2GoRunner(BaseRunner): ...@@ -570,7 +574,7 @@ class Detectron2GoRunner(BaseRunner):
if comm.is_main_process(): if comm.is_main_process():
data_loader_type = cls.get_data_loader_vis_wrapper() data_loader_type = cls.get_data_loader_vis_wrapper()
if data_loader_type is not None: if data_loader_type is not None:
tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR)) tbx_writer = cls.get_tbx_writer(cfg)
data_loader = data_loader_type(cfg, tbx_writer, data_loader) data_loader = data_loader_type(cfg, tbx_writer, data_loader)
return data_loader return data_loader
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment