"docs/git@developer.sourcefind.cn:change/sglang.git" did not exist on "43e20c06479ad8e1bbb640c159bb3a1bf76d70c8"
Commit d4aedb83 authored by Sam Tsai's avatar Sam Tsai Committed by Facebook GitHub Bot
Browse files

use src dataset name instead of the derived class name

Summary: "@ [0-9]classes" is appended to datasets to mark whether it is a derived class of the original one and saved as a config. When reloading the config, the derived class name will be used as the source instead of the original source. Adding a check to remove the derived suffix.

Reviewed By: wat3rBro

Differential Revision: D29315132

fbshipit-source-id: 0cc204d305d2da6c9f1817aaf631270bd874f90d
parent c480d4e4
...@@ -7,6 +7,7 @@ import contextlib ...@@ -7,6 +7,7 @@ import contextlib
import json import json
import logging import logging
import os import os
import re
import shutil import shutil
import tempfile import tempfile
from collections import defaultdict from collections import defaultdict
...@@ -229,6 +230,11 @@ class COCOSubsetWithGivenImages(AdhocCOCODataset): ...@@ -229,6 +230,11 @@ class COCOSubsetWithGivenImages(AdhocCOCODataset):
class COCOWithClassesToUse(AdhocCOCODataset): class COCOWithClassesToUse(AdhocCOCODataset):
def __init__(self, src_ds_name, classes_to_use): def __init__(self, src_ds_name, classes_to_use):
# check if name is already a derived class and try to reverse it
res = re.match("(?P<src>.+)@(?P<num>[0-9]+)classes", src_ds_name)
if res is not None:
src_ds_name = res['src']
super().__init__( super().__init__(
src_ds_name=src_ds_name, src_ds_name=src_ds_name,
new_ds_name="{}@{}classes".format(src_ds_name, len(classes_to_use)), new_ds_name="{}@{}classes".format(src_ds_name, len(classes_to_use)),
......
...@@ -12,7 +12,11 @@ from d2go.data.keypoint_metadata_registry import ( ...@@ -12,7 +12,11 @@ from d2go.data.keypoint_metadata_registry import (
KeypointMetadata, KeypointMetadata,
get_keypoint_metadata, get_keypoint_metadata,
) )
from d2go.data.utils import maybe_subsample_n_images from d2go.data.utils import (
maybe_subsample_n_images,
AdhocDatasetManager,
COCOWithClassesToUse,
)
from d2go.runner import Detectron2GoRunner from d2go.runner import Detectron2GoRunner
from d2go.utils.testing.data_loader_helper import ( from d2go.utils.testing.data_loader_helper import (
LocalImageGenerator, LocalImageGenerator,
...@@ -23,12 +27,14 @@ from detectron2.data import DatasetCatalog, MetadataCatalog ...@@ -23,12 +27,14 @@ from detectron2.data import DatasetCatalog, MetadataCatalog
from mobile_cv.common.misc.file_utils import make_temp_directory from mobile_cv.common.misc.file_utils import make_temp_directory
def create_test_images_and_dataset_json(data_dir): def create_test_images_and_dataset_json(data_dir, num_images=10, num_classes=-1):
# create image and json # create image and json
image_dir = os.path.join(data_dir, "images") image_dir = os.path.join(data_dir, "images")
os.makedirs(image_dir) os.makedirs(image_dir)
json_dataset, meta_data = create_toy_dataset( json_dataset, meta_data = create_toy_dataset(
LocalImageGenerator(image_dir, width=80, height=60), num_images=10 LocalImageGenerator(image_dir, width=80, height=60),
num_images=num_images,
num_classes=num_classes,
) )
json_file = os.path.join(data_dir, "{}.json".format("inj_ds1")) json_file = os.path.join(data_dir, "{}.json".format("inj_ds1"))
with open(json_file, "w") as f: with open(json_file, "w") as f:
...@@ -131,7 +137,7 @@ class TestD2GoDatasets(unittest.TestCase): ...@@ -131,7 +137,7 @@ class TestD2GoDatasets(unittest.TestCase):
"area": 100, "area": 100,
"bbox": [0, 0, 0, 0], "bbox": [0, 0, 0, 0],
}, },
] ],
] ]
out_dict_list = extended_coco.convert_to_dict_list( out_dict_list = extended_coco.convert_to_dict_list(
...@@ -256,3 +262,37 @@ class TestD2GoDatasets(unittest.TestCase): ...@@ -256,3 +262,37 @@ class TestD2GoDatasets(unittest.TestCase):
self.assertEqual(inj_md.keypoint_names[0], "A") self.assertEqual(inj_md.keypoint_names[0], "A")
self.assertEqual(inj_md.keypoint_flip_map[0][0], "A") self.assertEqual(inj_md.keypoint_flip_map[0][0], "A")
self.assertEqual(inj_md.keypoint_connection_rules[0][0], "A") self.assertEqual(inj_md.keypoint_connection_rules[0][0], "A")
@tempdir
def test_coco_create_adhoc_class_to_use_dataset(self, tmp_dir):
image_dir, json_file = create_test_images_and_dataset_json(
tmp_dir, num_classes=2
)
runner = Detectron2GoRunner()
cfg = runner.get_default_cfg()
cfg.merge_from_list(
[
str(x)
for x in [
"D2GO_DATA.DATASETS.COCO_INJECTION.NAMES",
["test_adhoc_ds", "test_adhoc_ds2"],
"D2GO_DATA.DATASETS.COCO_INJECTION.IM_DIRS",
[image_dir, image_dir],
"D2GO_DATA.DATASETS.COCO_INJECTION.JSON_FILES",
[json_file, json_file],
]
]
)
runner.register(cfg)
# Test adhoc classes to use
AdhocDatasetManager.add(COCOWithClassesToUse("test_adhoc_ds", ["class_0"]))
ds_list = DatasetCatalog.get("test_adhoc_ds@1classes")
self.assertEqual(len(ds_list), 5)
# Test adhoc classes to use with suffix removal
AdhocDatasetManager.add(COCOWithClassesToUse("test_adhoc_ds2@1classes", ["class_0"]))
ds_list = DatasetCatalog.get("test_adhoc_ds2@1classes")
self.assertEqual(len(ds_list), 5)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment