Commit 77ebe09f authored by Sam Tsai's avatar Sam Tsai Committed by Facebook GitHub Bot
Browse files

add keypoints metadata registry

Summary:
1. Add a keypoint metadata registry for registering different keypoint metadata
2. Add option to inject_coco_dataset for adding keypoint metadata

Reviewed By: newstzpz

Differential Revision: D27730541

fbshipit-source-id: c6ba97f60664fce4dcbb0de80222df7490bc6d5d
parent 2a1ed1ec
......@@ -17,10 +17,17 @@ def add_d2go_data_default_configs(_C):
_C.D2GO_DATA.DATASETS.TEST_CATEGORIES = ()
# Register a list of COCO datasets in config
# The following specifies additional coco data to inject. The required is the
# name (NAMES), image root (IM_DIRS), coco json file (JSON_FILES) while keypoint
# metadata (KEYPOINT_METADATA) is optional. The keypoint metadata name provided
# here is used to lookup the metadata specified within the KEYPOINT_METADATA
# metadata registry specified in "data/keypoint_metadata_registry.py". For adding
# new use cases, simply register new metadata to that registry.
_C.D2GO_DATA.DATASETS.COCO_INJECTION = CN()
_C.D2GO_DATA.DATASETS.COCO_INJECTION.NAMES = []
_C.D2GO_DATA.DATASETS.COCO_INJECTION.IM_DIRS = []
_C.D2GO_DATA.DATASETS.COCO_INJECTION.JSON_FILES = []
_C.D2GO_DATA.DATASETS.COCO_INJECTION.KEYPOINT_METADATA = []
# On-the-fly register a list of datasets located under detectron2go/datasets
# by specifying the filename (without .py).
......
......@@ -7,11 +7,12 @@ import importlib
import logging
import os
from detectron2.data import DatasetCatalog, MetadataCatalog
from d2go.utils.helper import get_dir_path
from detectron2.data import DatasetCatalog, MetadataCatalog
from .extended_coco import coco_text_load, extended_coco_load
from .extended_lvis import extended_lvis_load
from .keypoint_metadata_registry import get_keypoint_metadata
logger = logging.getLogger(__name__)
......@@ -104,10 +105,15 @@ def inject_coco_datasets(cfg):
names = cfg.D2GO_DATA.DATASETS.COCO_INJECTION.NAMES
im_dirs = cfg.D2GO_DATA.DATASETS.COCO_INJECTION.IM_DIRS
json_files = cfg.D2GO_DATA.DATASETS.COCO_INJECTION.JSON_FILES
metadata_type = cfg.D2GO_DATA.DATASETS.COCO_INJECTION.KEYPOINT_METADATA
assert len(names) == len(im_dirs) == len(json_files)
for name, im_dir, json_file in zip(names, im_dirs, json_files):
for ds_index, (name, im_dir, json_file) in enumerate(
zip(names, im_dirs, json_files)
):
split_dict = {IM_DIR: im_dir, ANN_FN: json_file}
if len(metadata_type) != 0:
split_dict["meta_data"] = get_keypoint_metadata(metadata_type[ds_index])
logger.info("Inject coco dataset {}: {}".format(name, split_dict))
_register_extended_coco(name, split_dict)
......
#!/usr/bin/env python3
from typing import NamedTuple, List, Tuple
from detectron2.utils.registry import Registry
KEYPOINT_METADATA_REGISTRY = Registry("KEYPOINT_METADATA")
KEYPOINT_METADATA_REGISTRY.__doc__ = "Registry keypoint metadata definitions"
class KeypointMetadata(NamedTuple):
names: List[str]
flip_map: List[Tuple[str, str]]
connection_rules: List[Tuple[str, str, Tuple[int, int, int]]]
def to_dict(self):
return {
"keypoint_names": self.names,
"keypoint_flip_map": self.flip_map,
"keypoint_connection_rules": self.connection_rules,
}
def get_keypoint_metadata(name):
return KEYPOINT_METADATA_REGISTRY.get(name)().to_dict()
......@@ -7,13 +7,19 @@ import os
import unittest
import d2go.data.extended_coco as extended_coco
from d2go.data.keypoint_metadata_registry import (
KEYPOINT_METADATA_REGISTRY,
KeypointMetadata,
get_keypoint_metadata,
)
from d2go.data.utils import maybe_subsample_n_images
from d2go.runner import Detectron2GoRunner
from d2go.utils.testing.data_loader_helper import (
LocalImageGenerator,
create_toy_dataset,
)
from detectron2.data import DatasetCatalog
from d2go.utils.testing.helper import tempdir
from detectron2.data import DatasetCatalog, MetadataCatalog
from mobile_cv.common.misc.file_utils import make_temp_directory
......@@ -68,9 +74,8 @@ class TestD2GoDatasets(unittest.TestCase):
self.assertEqual(out_json["images"][0]["id"], exp_output[0])
self.assertEqual(out_json["annotations"][0]["image_id"], exp_output[1])
def test_coco_injection(self):
with make_temp_directory("detectron2go_tmp_dataset") as tmp_dir:
@tempdir
def test_coco_injection(self, tmp_dir):
image_dir, json_file = create_test_images_and_dataset_json(tmp_dir)
runner = Detectron2GoRunner()
......@@ -96,8 +101,8 @@ class TestD2GoDatasets(unittest.TestCase):
self.assertEqual(dic["width"], 80)
self.assertEqual(dic["height"], 60)
def test_sub_dataset(self):
with make_temp_directory("detectron2go_tmp_dataset") as tmp_dir:
@tempdir
def test_sub_dataset(self, tmp_dir):
image_dir, json_file = create_test_images_and_dataset_json(tmp_dir)
runner = Detectron2GoRunner()
......@@ -126,3 +131,60 @@ class TestD2GoDatasets(unittest.TestCase):
new_cfg, new_cfg.DATASETS.TEST[0]
)
self.assertEqual(len(test_loader), 1)
def test_coco_metadata_registry(self):
@KEYPOINT_METADATA_REGISTRY.register()
def TriangleMetadata():
return KeypointMetadata(
names=("A", "B", "C"),
flip_map=(
("A", "B"),
("B", "C"),
),
connection_rules=[
("A", "B", (102, 204, 255)),
("B", "C", (51, 153, 255)),
],
)
tri_md = get_keypoint_metadata("TriangleMetadata")
self.assertEqual(tri_md["keypoint_names"][0], "A")
self.assertEqual(tri_md["keypoint_flip_map"][0][0], "A")
self.assertEqual(tri_md["keypoint_connection_rules"][0][0], "A")
@tempdir
def test_coco_metadata_register(self, tmp_dir):
@KEYPOINT_METADATA_REGISTRY.register()
def LineMetadata():
return KeypointMetadata(
names=("A", "B"),
flip_map=(("A", "B"),),
connection_rules=[
("A", "B", (102, 204, 255)),
],
)
image_dir, json_file = create_test_images_and_dataset_json(tmp_dir)
runner = Detectron2GoRunner()
cfg = runner.get_default_cfg()
cfg.merge_from_list(
[
str(x)
for x in [
"D2GO_DATA.DATASETS.COCO_INJECTION.NAMES",
["inj_ds"],
"D2GO_DATA.DATASETS.COCO_INJECTION.IM_DIRS",
[image_dir],
"D2GO_DATA.DATASETS.COCO_INJECTION.JSON_FILES",
[json_file],
"D2GO_DATA.DATASETS.COCO_INJECTION.KEYPOINT_METADATA",
["LineMetadata"],
]
]
)
runner.register(cfg)
inj_md = MetadataCatalog.get("inj_ds")
self.assertEqual(inj_md.keypoint_names[0], "A")
self.assertEqual(inj_md.keypoint_flip_map[0][0], "A")
self.assertEqual(inj_md.keypoint_connection_rules[0][0], "A")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment