# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import os
import unittest
import unittest.mock

import torch
from omegaconf import OmegaConf
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
    SequenceDataLoaderMapProvider,
    SimpleDataLoaderMapProvider,
)
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.tools.config import get_default_args
from tests.common_testing import get_tests_dir
from tests.implicitron.common_resources import get_skateboard_data

# Directory holding the reference files the tests compare against
# (e.g. data_source.yaml).
DATA_DIR = get_tests_dir() / "implicitron" / "data"
# Set to True to rewrite data_source.yaml from the current defaults
# instead of only comparing against it.
DEBUG: bool = False


class TestDataSource(unittest.TestCase):
    """Tests for building ImplicitronDataSource and its default configuration."""

    def setUp(self):
        self.maxDiff = None
        torch.manual_seed(42)

        # get_skateboard_data() is entered as a context manager yielding
        # (dataset_root, path_manager). Keep it open for the duration of the
        # test; addCleanup closes it even when the test fails.
        stack = contextlib.ExitStack()
        self.dataset_root, self.path_manager = stack.enter_context(
            get_skateboard_data()
        )
        self.addCleanup(stack.close)

    def _test_omegaconf_generic_failure(self):
        # OmegaConf possible bug - this is why we need _GenericWorkaround.
        # Structuring a dataclass that subclasses a parametrized generic
        # (Dataset[int]) is the reproducer. The leading underscore keeps
        # unittest from collecting this as a test.
        from dataclasses import dataclass

        @dataclass
        class D(torch.utils.data.Dataset[int]):
            a: int = 3

        OmegaConf.structured(D)

    def _test_omegaconf_ListList(self):
        # Demo that OmegaConf doesn't support nested lists. The leading
        # underscore keeps unittest from collecting this as a test.
        from dataclasses import dataclass
        from typing import Sequence

        @dataclass
        class A:
            a: Sequence[Sequence[int]] = ((32,),)

        OmegaConf.structured(A)

    def test_JsonIndexDataset_args(self):
        # test that JsonIndexDataset works with get_default_args
        get_default_args(JsonIndexDataset)

    def test_one(self):
        """The default ImplicitronDataSource config must match data_source.yaml.

        Set the module-level DEBUG flag to True to rewrite the reference file
        from the current defaults before comparing.
        """
        cfg = get_default_args(ImplicitronDataSource)
        # making the test invariant to env variables
        cfg.dataset_map_provider_JsonIndexDatasetMapProvider_args.dataset_root = ""
        cfg.dataset_map_provider_JsonIndexDatasetMapProviderV2_args.dataset_root = ""
        # making the test invariant to the presence of SQL dataset
        if "dataset_map_provider_SqlIndexDatasetMapProvider_args" in cfg:
            del cfg.dataset_map_provider_SqlIndexDatasetMapProvider_args

        yaml_text = OmegaConf.to_yaml(cfg, sort_keys=False)
        expected_path = DATA_DIR / "data_source.yaml"
        if DEBUG:
            expected_path.write_text(yaml_text)
        self.assertEqual(yaml_text, expected_path.read_text())

    def _skip_inside_re_worker(self):
        """Skip the current test when the INSIDE_RE_WORKER env var is set."""
        # Using skipTest rather than returning early: an early return is
        # reported as a pass even though nothing was checked.
        if os.environ.get("INSIDE_RE_WORKER") is not None:
            self.skipTest("INSIDE_RE_WORKER is set")

    def _skateboard_data_source_args(self, data_loader_type=None):
        """Return ImplicitronDataSource args pointing at the skateboard data.

        Args:
            data_loader_type: if not None, the value assigned to
                data_loader_map_provider_class_type; otherwise the default
                data loader map provider is left unchanged.

        Returns:
            The default args with the JsonIndexDatasetMapProvider configured
            for the "skateboard" category under self.dataset_root.
        """
        args = get_default_args(ImplicitronDataSource)
        args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
        if data_loader_type is not None:
            args.data_loader_map_provider_class_type = data_loader_type
        dataset_args = args.dataset_map_provider_JsonIndexDatasetMapProvider_args
        dataset_args.category = "skateboard"
        dataset_args.test_restrict_sequence_id = 0
        dataset_args.n_frames_per_sequence = -1
        dataset_args.dataset_root = self.dataset_root
        return args

    def _check_train_loader(self, data_source):
        """Check the train loader length and the frame type of its first batch."""
        _, data_loaders = data_source.get_datasets_and_dataloaders()
        self.assertEqual(len(data_loaders.train), 81)
        # Length is asserted non-zero above, so next() cannot hit StopIteration.
        first_batch = next(iter(data_loaders.train))
        self.assertEqual(first_batch.frame_type, ["test_known"])

    def test_default(self):
        # Without an explicit choice, the data loader map provider is
        # expected to be a SequenceDataLoaderMapProvider.
        self._skip_inside_re_worker()
        args = self._skateboard_data_source_args()
        data_source = ImplicitronDataSource(**args)
        self.assertIsInstance(
            data_source.data_loader_map_provider, SequenceDataLoaderMapProvider
        )
        self._check_train_loader(data_source)

    def test_simple(self):
        # Explicitly selecting SimpleDataLoaderMapProvider must be honoured.
        self._skip_inside_re_worker()
        args = self._skateboard_data_source_args(
            data_loader_type="SimpleDataLoaderMapProvider"
        )
        data_source = ImplicitronDataSource(**args)
        self.assertIsInstance(
            data_source.data_loader_map_provider, SimpleDataLoaderMapProvider
        )
        self._check_train_loader(data_source)