Commit daf7f294 authored by Peizhao Zhang's avatar Peizhao Zhang Committed by Facebook GitHub Bot
Browse files

Make data and evaluation visualization optional.

Summary:
Make data and evaluation visualization optional.
* could return None.

Reviewed By: zhanghang1989, wat3rBro

Differential Revision: D27316632

fbshipit-source-id: 2a85db4815cbf3407a20a74c125dcd52d75167fa
parent 66df06ef
...@@ -8,7 +8,7 @@ import math ...@@ -8,7 +8,7 @@ import math
import os import os
from collections import OrderedDict from collections import OrderedDict
from functools import lru_cache, partial from functools import lru_cache, partial
from typing import Type from typing import Type, Optional
import d2go.utils.abnormal_checker as abnormal_checker import d2go.utils.abnormal_checker as abnormal_checker
import detectron2.utils.comm as comm import detectron2.utils.comm as comm
...@@ -317,16 +317,18 @@ class Detectron2GoRunner(BaseRunner): ...@@ -317,16 +317,18 @@ class Detectron2GoRunner(BaseRunner):
tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR)) tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))
logger.info("Adding visualization evaluator ...") logger.info("Adding visualization evaluator ...")
mapper = self.get_mapper(cfg, is_train=False) mapper = self.get_mapper(cfg, is_train=False)
evaluator._evaluators.append( vis_eval_type = self.get_visualization_evaluator()
self.get_visualization_evaluator()( if vis_eval_type is not None:
cfg, evaluator._evaluators.append(
tbx_writer, vis_eval_type(
mapper, cfg,
dataset_name, tbx_writer,
train_iter=train_iter, mapper,
tag_postfix=model_tag, dataset_name,
train_iter=train_iter,
tag_postfix=model_tag,
)
) )
)
results_per_dataset = inference_on_dataset(model, data_loader, evaluator) results_per_dataset = inference_on_dataset(model, data_loader, evaluator)
...@@ -485,14 +487,14 @@ class Detectron2GoRunner(BaseRunner): ...@@ -485,14 +487,14 @@ class Detectron2GoRunner(BaseRunner):
data_loader = build_d2go_train_loader(cfg, mapper) data_loader = build_d2go_train_loader(cfg, mapper)
if comm.is_main_process(): if comm.is_main_process():
tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR)) data_loader_type = cls.get_data_loader_vis_wrapper()
data_loader = cls.get_data_loader_vis_wrapper()( if data_loader_type is not None:
cfg, tbx_writer, data_loader tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))
) data_loader = data_loader_type(cfg, tbx_writer, data_loader)
return data_loader return data_loader
@staticmethod @staticmethod
def get_data_loader_vis_wrapper() -> Type[DataLoaderVisWrapper]: def get_data_loader_vis_wrapper() -> Optional[Type[DataLoaderVisWrapper]]:
return DataLoaderVisWrapper return DataLoaderVisWrapper
@staticmethod @staticmethod
...@@ -524,7 +526,7 @@ class Detectron2GoRunner(BaseRunner): ...@@ -524,7 +526,7 @@ class Detectron2GoRunner(BaseRunner):
return mapper return mapper
@staticmethod @staticmethod
def get_visualization_evaluator() -> Type[VisualizationEvaluator]: def get_visualization_evaluator() -> Optional[Type[VisualizationEvaluator]]:
return VisualizationEvaluator return VisualizationEvaluator
@staticmethod @staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment