#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


import os
import shutil
import tempfile
import unittest

import torch
from d2go.data.disk_cache import DiskCachedList, ROOT_CACHE_DIR
from d2go.data.utils import configure_dataset_creation
from d2go.runner import create_runner
from d2go.utils.testing.data_loader_helper import (
    create_detection_data_loader_on_toy_dataset,
    register_toy_coco_dataset,
)


class TestD2GoDatasetMapper(unittest.TestCase):
    """
    This class test D2GoDatasetMapper which is used to build
    data loader in GeneralizedRCNNRunner (the default runner) in Detectron2Go.
    """

    def setUp(self):
        self.output_dir = tempfile.mkdtemp(prefix="TestD2GoDatasetMapper_")
        self.addCleanup(shutil.rmtree, self.output_dir)

    def test_default_dataset(self):
        runner = create_runner("d2go.runner.GeneralizedRCNNRunner")
        cfg = runner.get_default_cfg()
        cfg.DATASETS.TRAIN = ["default_dataset_train"]
        cfg.DATASETS.TEST = ["default_dataset_test"]
        cfg.OUTPUT_DIR = self.output_dir

        with register_toy_coco_dataset("default_dataset_train", num_images=3):
            train_loader = runner.build_detection_train_loader(cfg)
            for i, data in enumerate(train_loader):
                self.assertIsNotNone(data)
                # the train loader has infinite length, so stop after a few batches
                if i == 6:
                    break

        with register_toy_coco_dataset("default_dataset_test", num_images=3):
            test_loader = runner.build_detection_test_loader(
                cfg, dataset_name="default_dataset_test"
            )
            all_data = []
            for data in test_loader:
                all_data.append(data)
            self.assertEqual(len(all_data), 3)


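# Simple custom type used below to check that DiskCachedList can cache
# arbitrary Python objects, not just primitives and tensors.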
class _MyClass(object):
    def __init__(self, x):
        self.x = x

    def do_something(self):
        return


class TestDiskCachedDataLoader(unittest.TestCase):
    def setUp(self):
        # make sure the ROOT_CACHE_DIR is empty when entering the test
        if os.path.exists(ROOT_CACHE_DIR):
            shutil.rmtree(ROOT_CACHE_DIR)

        self.output_dir = tempfile.mkdtemp(prefix="TestDiskCachedDataLoader_")
        self.addCleanup(shutil.rmtree, self.output_dir)

    def _count_cache_dirs(self):
        if not os.path.exists(ROOT_CACHE_DIR):
            return 0

        return len(os.listdir(ROOT_CACHE_DIR))

    def test_disk_cached_dataset_from_list(self):
        """Test the class of DiskCachedList"""
        # check that the disk cache can handle different data types
        lst = [1, torch.tensor(2), _MyClass(3)]
        disk_cached_lst = DiskCachedList(lst)
        self.assertEqual(len(disk_cached_lst), 3)
        self.assertEqual(disk_cached_lst[0], 1)
        self.assertEqual(disk_cached_lst[1].item(), 2)
        self.assertEqual(disk_cached_lst[2].x, 3)

        # check the cache is created
        cache_dir = disk_cached_lst.cache_dir
        self.assertTrue(os.path.isdir(cache_dir))

        # check the cache is properly released
        del disk_cached_lst
        self.assertFalse(os.path.isdir(cache_dir))

    def test_disk_cached_dataloader(self):
        """Test the data loader backed by disk cache"""
        height = 6
        width = 8
        runner = create_runner("d2go.runner.GeneralizedRCNNRunner")
        cfg = runner.get_default_cfg()
        cfg.OUTPUT_DIR = self.output_dir
        cfg.DATALOADER.NUM_WORKERS = 2

        def _test_data_loader(data_loader):
            first_batch = next(iter(data_loader))
            self.assertEqual(first_batch[0]["height"], height)
            self.assertEqual(first_batch[0]["width"], width)

        # enable the disk cache
        cfg.merge_from_list(["D2GO_DATA.DATASETS.DISK_CACHE.ENABLED", "True"])
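
        # With the disk cache enabled, each data loader built under
        # configure_dataset_creation should materialize its own cache directory
        # under ROOT_CACHE_DIR and release it when the loader goes away;
        # _count_cache_dirs tracks this throughout the test.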
        with configure_dataset_creation(cfg):
            # no cache dirs at the beginning
            self.assertEqual(self._count_cache_dirs(), 0)

            with create_detection_data_loader_on_toy_dataset(
                cfg, height, width, is_train=True
            ) as train_loader:
                # train loader should create one cache dir
                self.assertEqual(self._count_cache_dirs(), 1)

                _test_data_loader(train_loader)

                with create_detection_data_loader_on_toy_dataset(
                    cfg, height, width, is_train=False
                ) as test_loader:
                    # test loader should create another cache dir
                    self.assertEqual(self._count_cache_dirs(), 2)

                    _test_data_loader(test_loader)

                # test loader should release its cache
                del test_loader
                self.assertEqual(self._count_cache_dirs(), 1)

            # no cache dirs at the end
            del train_loader
            self.assertEqual(self._count_cache_dirs(), 0)