Commit daf7f294 authored by Peizhao Zhang's avatar Peizhao Zhang Committed by Facebook GitHub Bot
Browse files

Make data and evaluation visualization optional.

Summary:
Make data and evaluation visualization optional.
* could return None.

Reviewed By: zhanghang1989, wat3rBro

Differential Revision: D27316632

fbshipit-source-id: 2a85db4815cbf3407a20a74c125dcd52d75167fa
parent 66df06ef
......@@ -8,7 +8,7 @@ import math
import os
from collections import OrderedDict
from functools import lru_cache, partial
from typing import Type
from typing import Type, Optional
import d2go.utils.abnormal_checker as abnormal_checker
import detectron2.utils.comm as comm
......@@ -317,16 +317,18 @@ class Detectron2GoRunner(BaseRunner):
tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))
logger.info("Adding visualization evaluator ...")
mapper = self.get_mapper(cfg, is_train=False)
evaluator._evaluators.append(
self.get_visualization_evaluator()(
cfg,
tbx_writer,
mapper,
dataset_name,
train_iter=train_iter,
tag_postfix=model_tag,
vis_eval_type = self.get_visualization_evaluator()
if vis_eval_type is not None:
evaluator._evaluators.append(
vis_eval_type(
cfg,
tbx_writer,
mapper,
dataset_name,
train_iter=train_iter,
tag_postfix=model_tag,
)
)
)
results_per_dataset = inference_on_dataset(model, data_loader, evaluator)
......@@ -485,14 +487,14 @@ class Detectron2GoRunner(BaseRunner):
data_loader = build_d2go_train_loader(cfg, mapper)
if comm.is_main_process():
tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))
data_loader = cls.get_data_loader_vis_wrapper()(
cfg, tbx_writer, data_loader
)
data_loader_type = cls.get_data_loader_vis_wrapper()
if data_loader_type is not None:
tbx_writer = _get_tbx_writer(get_tensorboard_log_dir(cfg.OUTPUT_DIR))
data_loader = data_loader_type(cfg, tbx_writer, data_loader)
return data_loader
@staticmethod
def get_data_loader_vis_wrapper() -> Type[DataLoaderVisWrapper]:
def get_data_loader_vis_wrapper() -> Optional[Type[DataLoaderVisWrapper]]:
return DataLoaderVisWrapper
@staticmethod
......@@ -524,7 +526,7 @@ class Detectron2GoRunner(BaseRunner):
return mapper
@staticmethod
def get_visualization_evaluator() -> Type[VisualizationEvaluator]:
def get_visualization_evaluator() -> Optional[Type[VisualizationEvaluator]]:
return VisualizationEvaluator
@staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment