Commit a0ee06f3 authored by Owen Wang, committed by Facebook GitHub Bot
Browse files

add .npy file handling in evaluator and visualizer

Summary: Detectron2[Go]'s Visualizer and sem_seg_evaluation are now updated with customization entry points for how semantic segmentation masks are read. By default, PIL and PNG images are expected. However, some projects' datasets use .npy files, and this customization allows providing an alternate Visualizer and evaluation function for reading them.

Reviewed By: newstzpz

Differential Revision: D33434948

fbshipit-source-id: 42af16d6708ffc5b2c03ec8507757313e23c8204
parent 2fe42c47
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Type, Optional
from detectron2.data import DatasetCatalog, MetadataCatalog, detection_utils as utils from detectron2.data import DatasetCatalog, MetadataCatalog, detection_utils as utils
from detectron2.evaluation import DatasetEvaluator from detectron2.evaluation import DatasetEvaluator
from detectron2.modeling import META_ARCH_REGISTRY from detectron2.modeling import META_ARCH_REGISTRY
...@@ -16,8 +18,9 @@ class VisualizerWrapper(object): ...@@ -16,8 +18,9 @@ class VisualizerWrapper(object):
the high-level interface for visualizing. the high-level interface for visualizing.
""" """
def __init__(self, cfg): def __init__(self, cfg, custom_visualizer: Optional[Type[Visualizer]] = None):
self.cfg = cfg self.cfg = cfg
self.visualizer = custom_visualizer or Visualizer
def _get_meta_arch_class(self): def _get_meta_arch_class(self):
return META_ARCH_REGISTRY.get(self.cfg.MODEL.META_ARCHITECTURE) return META_ARCH_REGISTRY.get(self.cfg.MODEL.META_ARCHITECTURE)
...@@ -38,7 +41,7 @@ class VisualizerWrapper(object): ...@@ -38,7 +41,7 @@ class VisualizerWrapper(object):
img = utils.convert_image_to_rgb(img, cfg.INPUT.FORMAT) img = utils.convert_image_to_rgb(img, cfg.INPUT.FORMAT)
metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]) metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])
scale = 2.0 scale = 2.0
visualizer = Visualizer(img, metadata=metadata, scale=scale) visualizer = self.visualizer(img, metadata=metadata, scale=scale)
if "instances" in per_image: if "instances" in per_image:
target_fields = per_image["instances"].get_fields() target_fields = per_image["instances"].get_fields()
...@@ -69,7 +72,7 @@ class VisualizerWrapper(object): ...@@ -69,7 +72,7 @@ class VisualizerWrapper(object):
) )
image = dataset_mapper._read_image(input_dict, "RGB") image = dataset_mapper._read_image(input_dict, "RGB")
visualizer = Visualizer(image, metadata=MetadataCatalog.get(dataset_name)) visualizer = self.visualizer(image, metadata=MetadataCatalog.get(dataset_name))
if "panoptic_seg" in output_dict: if "panoptic_seg" in output_dict:
# NOTE: refer to https://fburl.com/diffusion/evarrhbh # NOTE: refer to https://fburl.com/diffusion/evarrhbh
...@@ -90,7 +93,7 @@ class VisualizerWrapper(object): ...@@ -90,7 +93,7 @@ class VisualizerWrapper(object):
Visualize the dataset_dict Visualize the dataset_dict
""" """
image = dataset_mapper._read_image(dataset_dict, "RGB") image = dataset_mapper._read_image(dataset_dict, "RGB")
visualizer = Visualizer(image, metadata=MetadataCatalog.get(dataset_name)) visualizer = self.visualizer(image, metadata=MetadataCatalog.get(dataset_name))
visualizer.draw_dataset_dict(dataset_dict) visualizer.draw_dataset_dict(dataset_dict)
return visualizer.get_output().get_image() return visualizer.get_output().get_image()
...@@ -100,10 +103,16 @@ class DataLoaderVisWrapper: ...@@ -100,10 +103,16 @@ class DataLoaderVisWrapper:
Wrap the data loader to visualize its output via TensorBoardX at given frequency. Wrap the data loader to visualize its output via TensorBoardX at given frequency.
""" """
def __init__(self, cfg, tbx_writer, data_loader): def __init__(
self,
cfg,
tbx_writer,
data_loader,
visualizer: Optional[Type[VisualizerWrapper]] = None,
):
self.tbx_writer = tbx_writer self.tbx_writer = tbx_writer
self.data_loader = data_loader self.data_loader = data_loader
self._visualizer = VisualizerWrapper(cfg) self._visualizer = visualizer(cfg) if visualizer else VisualizerWrapper(cfg)
self.log_frequency = cfg.TENSORBOARD.TRAIN_LOADER_VIS_WRITE_PERIOD self.log_frequency = cfg.TENSORBOARD.TRAIN_LOADER_VIS_WRITE_PERIOD
self.log_limit = cfg.TENSORBOARD.TRAIN_LOADER_VIS_MAX_IMAGES self.log_limit = cfg.TENSORBOARD.TRAIN_LOADER_VIS_MAX_IMAGES
...@@ -179,11 +188,12 @@ class VisualizationEvaluator(DatasetEvaluator): ...@@ -179,11 +188,12 @@ class VisualizationEvaluator(DatasetEvaluator):
dataset_name, dataset_name,
train_iter=None, train_iter=None,
tag_postfix=None, tag_postfix=None,
visualizer: Optional[Type[VisualizerWrapper]] = None,
): ):
self.tbx_writer = tbx_writer self.tbx_writer = tbx_writer
self.dataset_mapper = dataset_mapper self.dataset_mapper = dataset_mapper
self.dataset_name = dataset_name self.dataset_name = dataset_name
self._visualizer = VisualizerWrapper(cfg) self._visualizer = visualizer(cfg) if visualizer else VisualizerWrapper(cfg)
self.train_iter = train_iter or VisualizationEvaluator._counter self.train_iter = train_iter or VisualizationEvaluator._counter
self.tag_postfix = tag_postfix or "" self.tag_postfix = tag_postfix or ""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment