Commit d4aedb83 authored by Sam Tsai's avatar Sam Tsai Committed by Facebook GitHub Bot
Browse files

use src dataset name instead of the derived class name

Summary: "@ [0-9]classes" is appended to datasets to mark whether it is a derived class of the original one and saved as a config. When reloading the config, the derived class name will be used as the source instead of the original source. Adding a check to remove the derived suffix.

Reviewed By: wat3rBro

Differential Revision: D29315132

fbshipit-source-id: 0cc204d305d2da6c9f1817aaf631270bd874f90d
parent c480d4e4
......@@ -7,6 +7,7 @@ import contextlib
import json
import logging
import os
import re
import shutil
import tempfile
from collections import defaultdict
......@@ -229,6 +230,11 @@ class COCOSubsetWithGivenImages(AdhocCOCODataset):
class COCOWithClassesToUse(AdhocCOCODataset):
def __init__(self, src_ds_name, classes_to_use):
# check if name is already a derived class and try to reverse it
res = re.match("(?P<src>.+)@(?P<num>[0-9]+)classes", src_ds_name)
if res is not None:
src_ds_name = res['src']
super().__init__(
src_ds_name=src_ds_name,
new_ds_name="{}@{}classes".format(src_ds_name, len(classes_to_use)),
......
......@@ -12,7 +12,11 @@ from d2go.data.keypoint_metadata_registry import (
KeypointMetadata,
get_keypoint_metadata,
)
from d2go.data.utils import maybe_subsample_n_images
from d2go.data.utils import (
maybe_subsample_n_images,
AdhocDatasetManager,
COCOWithClassesToUse,
)
from d2go.runner import Detectron2GoRunner
from d2go.utils.testing.data_loader_helper import (
LocalImageGenerator,
......@@ -23,12 +27,14 @@ from detectron2.data import DatasetCatalog, MetadataCatalog
from mobile_cv.common.misc.file_utils import make_temp_directory
def create_test_images_and_dataset_json(data_dir):
def create_test_images_and_dataset_json(data_dir, num_images=10, num_classes=-1):
# create image and json
image_dir = os.path.join(data_dir, "images")
os.makedirs(image_dir)
json_dataset, meta_data = create_toy_dataset(
LocalImageGenerator(image_dir, width=80, height=60), num_images=10
LocalImageGenerator(image_dir, width=80, height=60),
num_images=num_images,
num_classes=num_classes,
)
json_file = os.path.join(data_dir, "{}.json".format("inj_ds1"))
with open(json_file, "w") as f:
......@@ -131,7 +137,7 @@ class TestD2GoDatasets(unittest.TestCase):
"area": 100,
"bbox": [0, 0, 0, 0],
},
]
],
]
out_dict_list = extended_coco.convert_to_dict_list(
......@@ -256,3 +262,37 @@ class TestD2GoDatasets(unittest.TestCase):
self.assertEqual(inj_md.keypoint_names[0], "A")
self.assertEqual(inj_md.keypoint_flip_map[0][0], "A")
self.assertEqual(inj_md.keypoint_connection_rules[0][0], "A")
@tempdir
def test_coco_create_adhoc_class_to_use_dataset(self, tmp_dir):
image_dir, json_file = create_test_images_and_dataset_json(
tmp_dir, num_classes=2
)
runner = Detectron2GoRunner()
cfg = runner.get_default_cfg()
cfg.merge_from_list(
[
str(x)
for x in [
"D2GO_DATA.DATASETS.COCO_INJECTION.NAMES",
["test_adhoc_ds", "test_adhoc_ds2"],
"D2GO_DATA.DATASETS.COCO_INJECTION.IM_DIRS",
[image_dir, image_dir],
"D2GO_DATA.DATASETS.COCO_INJECTION.JSON_FILES",
[json_file, json_file],
]
]
)
runner.register(cfg)
# Test adhoc classes to use
AdhocDatasetManager.add(COCOWithClassesToUse("test_adhoc_ds", ["class_0"]))
ds_list = DatasetCatalog.get("test_adhoc_ds@1classes")
self.assertEqual(len(ds_list), 5)
# Test adhoc classes to use with suffix removal
AdhocDatasetManager.add(COCOWithClassesToUse("test_adhoc_ds2@1classes", ["class_0"]))
ds_list = DatasetCatalog.get("test_adhoc_ds2@1classes")
self.assertEqual(len(ds_list), 5)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment