Commit 9c877fd4 authored by Sam Tsai's avatar Sam Tsai Committed by Facebook GitHub Bot
Browse files

support multiple image visualization in dataloader visualization wrapper

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/160

If the returned object of visualize_train_input is a dictionary, use the key as tag suffix and the values as separate output images.

Reviewed By: zhanghang1989, wat3rBro

Differential Revision: D33468573

fbshipit-source-id: b0a47ba312ff59700534e917c62af1dfa83dd5be
parent b6e244d2
...@@ -43,7 +43,7 @@ class VisualizerWrapper(object): ...@@ -43,7 +43,7 @@ class VisualizerWrapper(object):
if "instances" in per_image: if "instances" in per_image:
target_fields = per_image["instances"].get_fields() target_fields = per_image["instances"].get_fields()
labels = [metadata.thing_classes[i] for i in target_fields["gt_classes"]] labels = [metadata.thing_classes[i] for i in target_fields["gt_classes"]]
vis = visualizer.overlay_instances( visualizer.overlay_instances(
labels=labels, labels=labels,
boxes=target_fields.get("gt_boxes", None), boxes=target_fields.get("gt_boxes", None),
masks=target_fields.get("gt_masks", None), masks=target_fields.get("gt_masks", None),
...@@ -51,11 +51,9 @@ class VisualizerWrapper(object): ...@@ -51,11 +51,9 @@ class VisualizerWrapper(object):
) )
if "sem_seg" in per_image: if "sem_seg" in per_image:
vis = visualizer.draw_sem_seg( visualizer.draw_sem_seg(per_image["sem_seg"], area_threshold=0, alpha=0.5)
per_image["sem_seg"], area_threshold=0, alpha=0.5
)
return vis.get_image() return visualizer.get_output().get_image()
def visualize_test_output( def visualize_test_output(
self, dataset_name, dataset_mapper, input_dict, output_dict self, dataset_name, dataset_mapper, input_dict, output_dict
...@@ -139,17 +137,27 @@ class DataLoaderVisWrapper: ...@@ -139,17 +137,27 @@ class DataLoaderVisWrapper:
for i, per_image in enumerate(data): for i, per_image in enumerate(data):
vis_image = self._visualizer.visualize_train_input(per_image) vis_image = self._visualizer.visualize_train_input(per_image)
tag = "train_loader_batch_{}/".format(storage.iter) tag = [f"train_loader_batch_{storage.iter}"]
if "dataset_name" in per_image: if "dataset_name" in per_image:
tag += per_image["dataset_name"] + "/" tag += [per_image["dataset_name"]]
if "file_name" in per_image: if "file_name" in per_image:
tag += "img_{}/{}".format(i, per_image["file_name"]) tag += [f"img_{i}", per_image["file_name"]]
self.tbx_writer._writer.add_image(
tag=tag, if isinstance(vis_image, dict):
img_tensor=vis_image, for k in vis_image:
global_step=storage.iter, self.tbx_writer._writer.add_image(
dataformats="HWC", tag="/".join(tag + [k]),
) img_tensor=vis_image[k],
global_step=storage.iter,
dataformats="HWC",
)
else:
self.tbx_writer._writer.add_image(
tag="/".join(tag),
img_tensor=vis_image,
global_step=storage.iter,
dataformats="HWC",
)
class VisualizationEvaluator(DatasetEvaluator): class VisualizationEvaluator(DatasetEvaluator):
......
...@@ -9,6 +9,7 @@ import unittest ...@@ -9,6 +9,7 @@ import unittest
from typing import Optional, List, Tuple, Dict from typing import Optional, List, Tuple, Dict
import d2go.runner.default_runner as default_runner import d2go.runner.default_runner as default_runner
import numpy as np
import torch import torch
from d2go.utils.testing.data_loader_helper import ( from d2go.utils.testing.data_loader_helper import (
LocalImageGenerator, LocalImageGenerator,
...@@ -17,6 +18,7 @@ from d2go.utils.testing.data_loader_helper import ( ...@@ -17,6 +18,7 @@ from d2go.utils.testing.data_loader_helper import (
from d2go.utils.testing.helper import tempdir from d2go.utils.testing.helper import tempdir
from d2go.utils.visualization import VisualizerWrapper, DataLoaderVisWrapper from d2go.utils.visualization import VisualizerWrapper, DataLoaderVisWrapper
from detectron2.data import DatasetCatalog from detectron2.data import DatasetCatalog
from detectron2.modeling import META_ARCH_REGISTRY
from detectron2.structures import Boxes, Instances from detectron2.structures import Boxes, Instances
from detectron2.utils.events import EventStorage from detectron2.utils.events import EventStorage
...@@ -52,6 +54,13 @@ def create_dummy_input_dict( ...@@ -52,6 +54,13 @@ def create_dummy_input_dict(
return input_dict return input_dict
@META_ARCH_REGISTRY.register()
class DummyMetaArch(torch.nn.Module):
@staticmethod
def visualize_train_input(visualizer_wrapper, input_dict):
return {"default": np.zeros((60, 60, 30)), "secondary": np.zeros((60, 60, 30))}
class ImageDictStore: class ImageDictStore:
def __init__(self): def __init__(self):
self.write_buffer = [] self.write_buffer = []
...@@ -140,3 +149,51 @@ class TestVisualization(unittest.TestCase): ...@@ -140,3 +149,51 @@ class TestVisualization(unittest.TestCase):
self.assertTrue("tag" in vis_image_dict) self.assertTrue("tag" in vis_image_dict)
self.assertTrue("img_tensor" in vis_image_dict) self.assertTrue("img_tensor" in vis_image_dict)
self.assertTrue("global_step" in vis_image_dict) self.assertTrue("global_step" in vis_image_dict)
@tempdir
def test_dict_based_dataloader_visualizer_wrapper(self, tmp_dir: str):
image_dir, json_file = create_test_images_and_dataset_json(tmp_dir, 60, 60)
# Create config data
runner = default_runner.Detectron2GoRunner()
cfg = runner.get_default_cfg()
cfg.merge_from_list(
[
"D2GO_DATA.DATASETS.COCO_INJECTION.NAMES",
str(["inj_ds3"]),
"D2GO_DATA.DATASETS.COCO_INJECTION.IM_DIRS",
str([image_dir]),
"D2GO_DATA.DATASETS.COCO_INJECTION.JSON_FILES",
str([json_file]),
"DATASETS.TRAIN",
str(["inj_ds3"]),
"MODEL.META_ARCHITECTURE",
"DummyMetaArch",
]
)
# Register configs
runner.register(cfg)
DatasetCatalog.get("inj_ds3")
with EventStorage():
# Create mock storage for writer
mock_tbx_writer = MockTbxWriter()
# Create a wrapper around an iterable object and run once
input_dict = create_dummy_input_dict(60, 60, [[1, 1, 2, 2]])
dl_wrapper = DataLoaderVisWrapper(
cfg, mock_tbx_writer, [[input_dict], [input_dict]]
)
for _ in dl_wrapper:
break
# Check data has been written to buffer
self.assertTrue(len(mock_tbx_writer._writer.write_buffer) == 2)
self.assertTrue(
"train_loader_batch_0/default"
in mock_tbx_writer._writer.write_buffer[0]["tag"]
)
self.assertTrue(
"train_loader_batch_0/secondary"
in mock_tbx_writer._writer.write_buffer[1]["tag"]
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment