"git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "59e992da71b678fbdbeb6181f8dd8fd81bdc495b"
Commit 7778f667 authored by Sam Tsai's avatar Sam Tsai Committed by Facebook GitHub Bot
Browse files

support using specified registration function for adhoc datasets

Summary:
Pull Request resolved: https://github.com/facebookresearch/mobile-vision/pull/61

Pull Request resolved: https://github.com/facebookresearch/d2go/pull/177

Adhoc datasets currently use default register functions. Changed to checking if it was registered in a look up table for injected coco and just using that instead.

Differential Revision: D33489049

fbshipit-source-id: bcb12bba49749a875ea80ae61f4eecc4a5d1e31a
parent 00409af8
...@@ -6,6 +6,7 @@ import functools ...@@ -6,6 +6,7 @@ import functools
import importlib import importlib
import logging import logging
import os import os
from collections import namedtuple
from d2go.utils.helper import get_dir_path from d2go.utils.helper import get_dir_path
from d2go.utils.misc import fb_overwritable from d2go.utils.misc import fb_overwritable
...@@ -27,6 +28,10 @@ COCO_REGISTER_FUNCTION_REGISTRY = Registry("COCO_REGISTER_FUNCTION_REGISTRY") ...@@ -27,6 +28,10 @@ COCO_REGISTER_FUNCTION_REGISTRY = Registry("COCO_REGISTER_FUNCTION_REGISTRY")
COCO_REGISTER_FUNCTION_REGISTRY.__doc__ = "Registry - coco register function" COCO_REGISTER_FUNCTION_REGISTRY.__doc__ = "Registry - coco register function"
InjectedCocoEntry = namedtuple("InjectedCocoEntry", ["func", "split_dict"])
INJECTED_COCO_DATASETS_LUT = {}
def get_coco_register_function(cfg): def get_coco_register_function(cfg):
name = cfg.D2GO_DATA.DATASETS.COCO_INJECTION.REGISTER_FUNCTION name = cfg.D2GO_DATA.DATASETS.COCO_INJECTION.REGISTER_FUNCTION
return COCO_REGISTER_FUNCTION_REGISTRY.get(name) return COCO_REGISTER_FUNCTION_REGISTRY.get(name)
...@@ -134,6 +139,9 @@ def inject_coco_datasets(cfg): ...@@ -134,6 +139,9 @@ def inject_coco_datasets(cfg):
split_dict["meta_data"] = get_keypoint_metadata(metadata_type[ds_index]) split_dict["meta_data"] = get_keypoint_metadata(metadata_type[ds_index])
logger.info("Inject coco dataset {}: {}".format(name, split_dict)) logger.info("Inject coco dataset {}: {}".format(name, split_dict))
_register_func(name, split_dict) _register_func(name, split_dict)
INJECTED_COCO_DATASETS_LUT[name] = InjectedCocoEntry(
func=_register_func, split_dict=split_dict
)
def register_dataset_split(dataset_name, split_dict): def register_dataset_split(dataset_name, split_dict):
......
...@@ -19,7 +19,12 @@ logger = logging.getLogger(__name__) ...@@ -19,7 +19,12 @@ logger = logging.getLogger(__name__)
from d2go.config import temp_defrost from d2go.config import temp_defrost
from d2go.data.datasets import register_dataset_split, ANN_FN, IM_DIR from d2go.data.datasets import (
register_dataset_split,
ANN_FN,
IM_DIR,
INJECTED_COCO_DATASETS_LUT,
)
from detectron2.data import DatasetCatalog, MetadataCatalog from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.utils.file_io import PathManager from detectron2.utils.file_io import PathManager
...@@ -137,6 +142,12 @@ class AdhocCOCODataset(AdhocDataset): ...@@ -137,6 +142,12 @@ class AdhocCOCODataset(AdhocDataset):
if isinstance(load_func, CallFuncWithJsonFile): if isinstance(load_func, CallFuncWithJsonFile):
new_func = CallFuncWithJsonFile(func=load_func.func, json_file=tmp_file) new_func = CallFuncWithJsonFile(func=load_func.func, json_file=tmp_file)
DatasetCatalog.register(self.new_ds_name, new_func) DatasetCatalog.register(self.new_ds_name, new_func)
elif self.src_ds_name in INJECTED_COCO_DATASETS_LUT:
_src_func, _src_dict = INJECTED_COCO_DATASETS_LUT[self.src_ds_name]
_src_func(
self.new_ds_name,
split_dict={**_src_dict, ANN_FN: tmp_file, IM_DIR: metadata.image_root},
)
else: else:
# NOTE: only supports COCODataset as DS_TYPE since we cannot reconstruct # NOTE: only supports COCODataset as DS_TYPE since we cannot reconstruct
# the split_dict # the split_dict
......
...@@ -8,6 +8,7 @@ import tempfile ...@@ -8,6 +8,7 @@ import tempfile
import unittest import unittest
import d2go.data.extended_coco as extended_coco import d2go.data.extended_coco as extended_coco
from d2go.data.datasets import COCO_REGISTER_FUNCTION_REGISTRY, ANN_FN, IM_DIR
from d2go.data.keypoint_metadata_registry import ( from d2go.data.keypoint_metadata_registry import (
KEYPOINT_METADATA_REGISTRY, KEYPOINT_METADATA_REGISTRY,
KeypointMetadata, KeypointMetadata,
...@@ -322,3 +323,87 @@ class TestD2GoDatasets(unittest.TestCase): ...@@ -322,3 +323,87 @@ class TestD2GoDatasets(unittest.TestCase):
) )
ds_list = DatasetCatalog.get("test_adhoc_ds2@1classes") ds_list = DatasetCatalog.get("test_adhoc_ds2@1classes")
self.assertEqual(len(ds_list), 5) self.assertEqual(len(ds_list), 5)
@tempdir
def test_register_coco_dataset_registry(self, tmp_dir):
dummy_buffer = []
@COCO_REGISTER_FUNCTION_REGISTRY.register()
def _register_dummy_function_coco(dataset_name, split_dict):
dummy_buffer.append((dataset_name, split_dict))
image_dir, json_file = create_test_images_and_dataset_json(tmp_dir)
runner = Detectron2GoRunner()
cfg = runner.get_default_cfg()
cfg.merge_from_list(
[
str(x)
for x in [
"D2GO_DATA.DATASETS.COCO_INJECTION.NAMES",
["inj_test_registry"],
"D2GO_DATA.DATASETS.COCO_INJECTION.IM_DIRS",
[image_dir],
"D2GO_DATA.DATASETS.COCO_INJECTION.JSON_FILES",
[json_file],
"D2GO_DATA.DATASETS.COCO_INJECTION.REGISTER_FUNCTION",
"_register_dummy_function_coco",
]
]
)
runner.register(cfg)
self.assertTrue(len(dummy_buffer) == 1)
@tempdir
def test_adhoc_register_coco_dataset_registry(self, tmp_dir):
dummy_buffer = []
def _dummy_load_func():
return []
@COCO_REGISTER_FUNCTION_REGISTRY.register()
def _register_dummy_function_coco_adhoc(dataset_name, split_dict):
json_file = split_dict[ANN_FN]
image_root = split_dict[IM_DIR]
DatasetCatalog.register(dataset_name, _dummy_load_func)
MetadataCatalog.get(dataset_name).set(
evaluator_type="coco",
json_file=json_file,
image_root=image_root,
)
dummy_buffer.append((dataset_name, split_dict))
image_dir, json_file = create_test_images_and_dataset_json(tmp_dir)
runner = Detectron2GoRunner()
cfg = runner.get_default_cfg()
cfg.merge_from_list(
[
str(x)
for x in [
"D2GO_DATA.DATASETS.COCO_INJECTION.NAMES",
["inj_test_registry_adhoc"],
"D2GO_DATA.DATASETS.COCO_INJECTION.IM_DIRS",
[image_dir],
"D2GO_DATA.DATASETS.COCO_INJECTION.JSON_FILES",
[json_file],
"D2GO_DATA.DATASETS.COCO_INJECTION.REGISTER_FUNCTION",
"_register_dummy_function_coco_adhoc",
]
]
)
runner.register(cfg)
self.assertTrue(len(dummy_buffer) == 1)
# Add adhoc class that uses only the first class
AdhocDatasetManager.add(
COCOWithClassesToUse("inj_test_registry_adhoc", ["class_0"])
)
# Check that the correct register function is used
self.assertTrue(len(dummy_buffer) == 2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment