Commit b21607b1 authored by Tao Xu's avatar Tao Xu Committed by Facebook GitHub Bot

fix the issue of tensorboard visualization

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/473

As shown in the attached image and tb visualization, some of our jobs fail to save their results to tensorboard.
If the images had been added to tensorboard, there would be log messages between the circled lines of the screenshot.
One possible reason is that the tensorboard visualization evaluator is only added on the rank-0 gpu, so it may fail to fetch any data during evaluation of the diffusion model, which only runs 1 batch of inference during validation.
To resolve this issue, we add the visualization evaluator to all gpu ranks, gather their results, and finally write the results with the largest batch size to tensorboard for visualization.
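The gather-and-select step described above can be sketched in plain Python (a minimal sketch: in the real code the per-rank results would be collected with detectron2's `comm.gather`; the function name and the `"images"` key here are illustrative, not from the patch):

```python
def select_largest_batch(per_rank_results):
    """Pick the per-rank visualization result with the most images.

    per_rank_results: one entry per gpu rank, as gathered on rank 0.
    Ranks that fetched no data during the 1-batch validation contribute
    None and are skipped; if every rank is empty, return None.
    """
    candidates = [r for r in per_rank_results if r is not None]
    if not candidates:
        return None
    return max(candidates, key=lambda r: len(r["images"]))


# Simulated gather: only rank 1 happened to receive the single
# validation batch, so its result is the one written to tensorboard.
gathered = [None, {"rank": 1, "images": ["img0", "img1"]}, None, None]
best = select_largest_batch(gathered)
```

With the evaluator on rank 0 only, `gathered` could be `[None]` and nothing would reach tensorboard; running it on every rank makes at least one non-empty candidate likely.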

The screenshot is from f410204704 (https://www.internalfb.com/manifold/explorer/mobile_vision_workflows/tree/workflows/xutao/20230211/latest_train/dalle2_decoder.SIULDLpgix/e2e_train/log.txt)

Refactored default_runner.py to add a new function, _create_evaluators, that creates all evaluators. This way, a runner that needs to add the visualization evaluator on all ranks no longer has to override the whole _do_test function.
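The benefit of the refactor can be sketched with stub classes (a minimal sketch, not the real runner API; `VisRunner` and the string "evaluators" stand in for the actual evaluator objects in default_runner.py):

```python
class DatasetEvaluators:
    """Stub for detectron2's DatasetEvaluators: holds a list of evaluators."""

    def __init__(self, evaluators):
        self._evaluators = list(evaluators)


class BaseRunner:
    def _create_evaluators(self, cfg, dataset_name):
        # Default behavior: build the standard evaluator(s) only.
        return DatasetEvaluators([f"accuracy/{dataset_name}"])

    def _do_test(self, cfg, dataset_name):
        # _do_test stays untouched; it just calls the hook, so
        # subclasses never need to copy its body.
        return self._create_evaluators(cfg, dataset_name)


class VisRunner(BaseRunner):
    def _create_evaluators(self, cfg, dataset_name):
        # Override only the hook to append a visualization evaluator
        # (on every rank, in the real fix).
        evaluator = super()._create_evaluators(cfg, dataset_name)
        evaluator._evaluators.append(f"visualization/{dataset_name}")
        return evaluator
```

Before the refactor, `VisRunner` would have had to duplicate all of `_do_test` just to change how evaluators are constructed.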

(Note: this ignores all push blocking failures!)

Reviewed By: YanjunChen329

Differential Revision: D43263543

fbshipit-source-id: eca2259277584819dcc5400d47fa4fb142f2ed9b
parent 31197c3e
@@ -328,6 +328,32 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
     def build_lr_scheduler(self, cfg, optimizer):
         return d2_build_lr_scheduler(cfg, optimizer)

+    def _create_evaluators(
+        self, cfg, dataset_name, output_folder, train_iter, model_tag
+    ):
+        evaluator = self.get_evaluator(cfg, dataset_name, output_folder=output_folder)
+        if not isinstance(evaluator, DatasetEvaluators):
+            evaluator = DatasetEvaluators([evaluator])
+        if comm.is_main_process():
+            # Add evaluator for visualization only to rank 0
+            tbx_writer = self.get_tbx_writer(cfg)
+            logger.info("Adding visualization evaluator ...")
+            mapper = self.get_mapper(cfg, is_train=False)
+            vis_eval_type = self.get_visualization_evaluator()
+            if vis_eval_type is not None:
+                evaluator._evaluators.append(
+                    vis_eval_type(
+                        cfg,
+                        tbx_writer,
+                        mapper,
+                        dataset_name,
+                        train_iter=train_iter,
+                        tag_postfix=model_tag,
+                    )
+                )
+        return evaluator
+
     def _do_test(self, cfg, model, train_iter=None, model_tag="default"):
         """train_iter: Current iteration of the model, None means final iteration"""
         assert len(cfg.DATASETS.TEST)
@@ -361,29 +387,10 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
             # NOTE: creating evaluator after dataset is loaded as there might be dependency.  # noqa
             data_loader = self.build_detection_test_loader(cfg, dataset_name)
-            evaluator = self.get_evaluator(
-                cfg, dataset_name, output_folder=output_folder
-            )
-            if not isinstance(evaluator, DatasetEvaluators):
-                evaluator = DatasetEvaluators([evaluator])
-            if comm.is_main_process():
-                tbx_writer = self.get_tbx_writer(cfg)
-                logger.info("Adding visualization evaluator ...")
-                mapper = self.get_mapper(cfg, is_train=False)
-                vis_eval_type = self.get_visualization_evaluator()
-                if vis_eval_type is not None:
-                    evaluator._evaluators.append(
-                        vis_eval_type(
-                            cfg,
-                            tbx_writer,
-                            mapper,
-                            dataset_name,
-                            train_iter=train_iter,
-                            tag_postfix=model_tag,
-                        )
-                    )
+            evaluator = self._create_evaluators(
+                cfg, dataset_name, output_folder, train_iter, model_tag
+            )
             results_per_dataset = inference_on_dataset(model, data_loader, evaluator)
             if comm.is_main_process():
...