test_d2go_datasets.py 3.11 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


import json
import os
import unittest

from detectron2.data import DatasetCatalog
from d2go.data.utils import maybe_subsample_n_images
from d2go.runner import Detectron2GoRunner
from mobile_cv.common.misc.file_utils import make_temp_directory

from d2go.tests.data_loader_helper import LocalImageGenerator, create_toy_dataset


def create_test_images_and_dataset_json(data_dir):
    """Create a toy image directory and a JSON dataset file under ``data_dir``.

    An ``images`` sub-directory is created inside ``data_dir`` and handed to a
    ``LocalImageGenerator`` set up with width 80 and height 60.
    ``create_toy_dataset`` is then called with ``num_images=10`` and the
    dataset it returns is dumped as JSON to ``inj_ds1.json`` in ``data_dir``.

    Args:
        data_dir: Existing directory in which the images folder and the JSON
            file are created.

    Returns:
        A ``(image_dir, json_file)`` tuple holding the two created paths.

    Raises:
        OSError: If the ``images`` sub-directory already exists (it is
            created without ``exist_ok``).
    """
    image_dir = os.path.join(data_dir, "images")
    os.makedirs(image_dir)

    generator = LocalImageGenerator(image_dir, width=80, height=60)
    # The second element returned next to the dataset is not needed here.
    toy_dataset, _ = create_toy_dataset(generator, num_images=10)

    dataset_name = "inj_ds1"
    json_file = os.path.join(data_dir, "{}.json".format(dataset_name))
    with open(json_file, "w") as out:
        json.dump(toy_dataset, out)

    return image_dir, json_file


class TestD2GoDatasets(unittest.TestCase):
    """Tests for COCO dataset injection and test-set subsampling via the runner."""

    @staticmethod
    def _to_opts(overrides):
        # Flatten (key, value) pairs into the alternating key/value list that
        # cfg.merge_from_list is given; every entry is passed as a string.
        opts = []
        for key, value in overrides:
            opts.append(str(key))
            opts.append(str(value))
        return opts

    def test_coco_injection(self):
        # Two datasets are listed under COCO_INJECTION, but only "inj_ds1"
        # (backed by the generated toy data) is fetched and checked.
        with make_temp_directory("detectron2go_tmp_dataset") as work_dir:
            image_dir, json_file = create_test_images_and_dataset_json(work_dir)

            runner = Detectron2GoRunner()
            cfg = runner.get_default_cfg()
            overrides = (
                (
                    "D2GO_DATA.DATASETS.COCO_INJECTION.NAMES",
                    ["inj_ds1", "inj_ds2"],
                ),
                (
                    "D2GO_DATA.DATASETS.COCO_INJECTION.IM_DIRS",
                    [image_dir, "/mnt/fair"],
                ),
                (
                    "D2GO_DATA.DATASETS.COCO_INJECTION.JSON_FILES",
                    [json_file, "inj_ds2"],
                ),
            )
            cfg.merge_from_list(self._to_opts(overrides))

            runner.register(cfg)
            records = DatasetCatalog.get("inj_ds1")

            # The toy dataset was generated with 10 images of size 80x60.
            self.assertEqual(len(records), 10)
            for record in records:
                self.assertEqual(record["width"], 80)
                self.assertEqual(record["height"], 60)

    def test_sub_dataset(self):
        # With D2GO_DATA.TEST.MAX_IMAGES set to 1, the test loader built from
        # the config yielded by maybe_subsample_n_images must have length 1.
        with make_temp_directory("detectron2go_tmp_dataset") as work_dir:
            image_dir, json_file = create_test_images_and_dataset_json(work_dir)

            runner = Detectron2GoRunner()
            cfg = runner.get_default_cfg()
            overrides = (
                ("D2GO_DATA.DATASETS.COCO_INJECTION.NAMES", ["inj_ds"]),
                ("D2GO_DATA.DATASETS.COCO_INJECTION.IM_DIRS", [image_dir]),
                ("D2GO_DATA.DATASETS.COCO_INJECTION.JSON_FILES", [json_file]),
                ("DATASETS.TEST", ("inj_ds",)),
                ("D2GO_DATA.TEST.MAX_IMAGES", 1),
            )
            cfg.merge_from_list(self._to_opts(overrides))

            runner.register(cfg)
            with maybe_subsample_n_images(cfg) as new_cfg:
                first_test_set = new_cfg.DATASETS.TEST[0]
                test_loader = runner.build_detection_test_loader(
                    new_cfg, first_test_set
                )
                self.assertEqual(len(test_loader), 1)