#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


import os
import shutil
import unittest

import torch

from d2go.data.disk_cache import DiskCachedDatasetFromList
from d2go.data.utils import enable_disk_cached_dataset
from d2go.runner import create_runner
from d2go.utils.testing.data_loader_helper import (
    create_fake_detection_data_loader,
    register_toy_coco_dataset,
)


class TestD2GoDatasetMapper(unittest.TestCase):
    """
    Exercise D2GoDatasetMapper, which is used to build the data loaders of
    GeneralizedRCNNRunner (the default runner) in Detectron2Go.
    """

    def test_default_dataset(self):
        runner = create_runner("d2go.runner.GeneralizedRCNNRunner")
        cfg = runner.get_default_cfg()
        cfg.DATASETS.TRAIN = ["default_dataset_train"]
        cfg.DATASETS.TEST = ["default_dataset_test"]

        with register_toy_coco_dataset("default_dataset_train", num_images=3):
            loader = runner.build_detection_train_loader(cfg)
            # The train loader cycles forever, so only sample a few batches
            # and bail out instead of iterating to exhaustion.
            for batch_idx, batch in enumerate(loader):
                self.assertIsNotNone(batch)
                if batch_idx == 6:
                    break

        with register_toy_coco_dataset("default_dataset_test", num_images=3):
            loader = runner.build_detection_test_loader(
                cfg, dataset_name="default_dataset_test"
            )
            # The test loader is finite; materializing it should yield one
            # entry per registered image.
            collected = list(loader)
            self.assertEqual(len(collected), 3)


class _MyClass(object):
    def __init__(self, x):
        self.x = x

    def do_something(self):
        return


class TestDiskCachedDataLoader(unittest.TestCase):
    """Tests for DiskCachedDatasetFromList and the disk-cache-backed loader."""

    def setUp(self):
        # make sure the CACHE_DIR is empty when entering the test
        if os.path.exists(DiskCachedDatasetFromList.CACHE_DIR):
            shutil.rmtree(DiskCachedDatasetFromList.CACHE_DIR)

    def _count_cache_dirs(self):
        """Return the number of per-dataset cache directories on disk."""
        if not os.path.exists(DiskCachedDatasetFromList.CACHE_DIR):
            return 0

        return len(os.listdir(DiskCachedDatasetFromList.CACHE_DIR))

    def test_disk_cached_dataset_from_list(self):
        """Test the class of DiskCachedDatasetFromList"""
        # check the disk cache can handle different data types
        lst = [1, torch.tensor(2), _MyClass(3)]
        disk_cached_lst = DiskCachedDatasetFromList(lst)
        self.assertEqual(len(disk_cached_lst), 3)
        self.assertEqual(disk_cached_lst[0], 1)
        self.assertEqual(disk_cached_lst[1].item(), 2)
        self.assertEqual(disk_cached_lst[2].x, 3)

        # check the cache is created
        cache_dir = disk_cached_lst.get_cache_dir()
        self.assertTrue(os.path.isdir(cache_dir))

        # check the cache is properly released when the dataset is destroyed
        del disk_cached_lst
        self.assertFalse(os.path.isdir(cache_dir))

    def test_disk_cached_dataloader(self):
        """Test the data loader backed by disk cache"""
        height = 6
        width = 8
        runner = create_runner("d2go.runner.GeneralizedRCNNRunner")
        cfg = runner.get_default_cfg()

        def _test_data_loader(data_loader):
            first_batch = next(iter(data_loader))
            # BUG FIX: these were `assertTrue(value, msg)` — the second
            # argument of assertTrue is a failure *message*, so the old
            # assertions passed for any truthy value and never compared
            # against the expected size. Use assertEqual to actually check.
            self.assertEqual(first_batch[0]["height"], height)
            self.assertEqual(first_batch[0]["width"], width)

        # enable the disk cache
        cfg.merge_from_list(["D2GO_DATA.DATASETS.DISK_CACHE.ENABLED", "True"])
        with enable_disk_cached_dataset(cfg):
            # no cache dir in the beginning
            self.assertEqual(self._count_cache_dirs(), 0)

            with create_fake_detection_data_loader(
                height, width, is_train=True
            ) as train_loader:
                # train loader should create one cache dir
                self.assertEqual(self._count_cache_dirs(), 1)

                _test_data_loader(train_loader)

                with create_fake_detection_data_loader(
                    height, width, is_train=False
                ) as test_loader:
                    # test loader should create another cache dir
                    self.assertEqual(self._count_cache_dirs(), 2)

                    _test_data_loader(test_loader)

                # test loader should release its cache once dereferenced
                del test_loader
                self.assertEqual(self._count_cache_dirs(), 1)

            # no cache dir in the end
            del train_loader
            self.assertEqual(self._count_cache_dirs(), 0)