# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings
from typing import List, Set, Union

import numpy as np
from mmengine.dataset import BaseDataset, force_full_init

from mmdet3d.registry import DATASETS


@DATASETS.register_module()
class CBGSDataset:
    """A wrapper of class sampled dataset with ann_file path. Implementation of
    paper `Class-balanced Grouping and Sampling for Point Cloud 3D Object
    Detection <https://arxiv.org/abs/1908.09492>`_.

    Balance the number of scenes under different classes. Each class's sample
    indices are redrawn with replacement (``np.random.choice``) so that every
    class contributes roughly the same number of entries to the wrapped index
    list. Classes that never occur in the wrapped dataset are skipped.

    Args:
        dataset (:obj:`BaseDataset` or dict): The dataset to be class sampled.
        lazy_init (bool): Whether to defer loading annotations until
            :meth:`full_init` is called. Defaults to False, i.e. annotations
            are loaded and class sampling is performed during instantiation.
    """

    def __init__(self,
                 dataset: Union[BaseDataset, dict],
                 lazy_init: bool = False) -> None:
        self.dataset: BaseDataset
        if isinstance(dataset, dict):
            self.dataset = DATASETS.build(dataset)
        elif isinstance(dataset, BaseDataset):
            self.dataset = dataset
        else:
            raise TypeError(
                'elements in datasets sequence should be config or '
                f'`BaseDataset` instance, but got {type(dataset)}')
        self._metainfo = self.dataset.metainfo

        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

    @property
    def metainfo(self) -> dict:
        """Get the meta information of the class balanced dataset.

        Returns:
            dict: A deep copy of the wrapped dataset's meta information, so
            callers cannot mutate the wrapper's internal state.
        """
        return copy.deepcopy(self._metainfo)

    def full_init(self) -> None:
        """Fully initialize the wrapped dataset and build sample indices.

        Calling this more than once is a no-op after the first call.
        """
        if self._fully_initialized:
            return

        self.dataset.full_init()
        # Get sample_indices
        self.sample_indices = self._get_sample_indices(self.dataset)

        self._fully_initialized = True

    def _get_sample_indices(self, dataset: BaseDataset) -> List[int]:
        """Load sample indices according to ann_file.

        Every class receives ``int(len(cls_inds) * ratio)`` indices drawn with
        replacement from the samples containing it, where ``ratio`` rescales
        the class's share of all (class, sample) pairs to ``1 / num_classes``.

        Args:
            dataset (:obj:`BaseDataset`): The dataset.

        Returns:
            List[int]: List of indices after class sampling. Empty (with a
            warning) when no sample carries any category.
        """
        classes = self.metainfo['classes']
        cat2id = {name: i for i, name in enumerate(classes)}
        class_sample_idxs = {cat_id: [] for cat_id in cat2id.values()}
        for idx in range(len(dataset)):
            sample_cat_ids = dataset.get_cat_ids(idx)
            for cat_id in sample_cat_ids:
                # Filter categories that do not need to be cared.
                # -1 indicates dontcare in MMDet3D.
                if cat_id != -1:
                    class_sample_idxs[cat_id].append(idx)

        # A sample holding several classes is counted once per class, hence
        # "duplicated".
        duplicated_samples = sum(
            len(v) for v in class_sample_idxs.values())
        if duplicated_samples == 0:
            # Without any labelled sample the class distribution is undefined
            # (it would divide by zero); there is nothing to balance.
            warnings.warn('CBGSDataset found no samples with valid '
                          'categories; the class balanced dataset is empty.')
            return []

        class_distribution = {
            k: len(v) / duplicated_samples
            for k, v in class_sample_idxs.items()
        }

        sample_indices = []

        frac = 1.0 / len(classes)
        for cat_id, cls_inds in class_sample_idxs.items():
            # A class absent from the dataset has a share of 0, which would
            # make ``frac / share`` divide by zero; it has nothing to sample.
            if not cls_inds:
                continue
            ratio = frac / class_distribution[cat_id]
            sample_indices += np.random.choice(cls_inds,
                                               int(len(cls_inds) *
                                                   ratio)).tolist()
        return sample_indices

    @force_full_init
    def _get_ori_dataset_idx(self, idx: int) -> int:
        """Convert global index to local index.

        Args:
            idx (int): Global index of ``CBGSDataset``.

        Returns:
            int: Local index of data in the wrapped dataset.
        """
        return self.sample_indices[idx]

    @force_full_init
    def get_cat_ids(self, idx: int) -> Set[int]:
        """Get category ids of class balanced dataset by index.

        Args:
            idx (int): Index of data.

        Returns:
            Set[int]: All categories in the sample of specified index.
        """
        sample_idx = self._get_ori_dataset_idx(idx)
        return self.dataset.get_cat_ids(sample_idx)

    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): Global index of ``CBGSDataset``.

        Returns:
            dict: The idx-th annotation of the dataset.
        """
        sample_idx = self._get_ori_dataset_idx(idx)
        return self.dataset.get_data_info(sample_idx)

    def __getitem__(self, idx: int) -> dict:
        """Get item from infos according to the given index.

        Lazily runs :meth:`full_init` (with a warning) if the wrapper was
        created with ``lazy_init=True`` and not yet initialized.

        Args:
            idx (int): The index of self.sample_indices.

        Returns:
            dict: Data dictionary of the corresponding index.
        """
        if not self._fully_initialized:
            warnings.warn('Please call `full_init` method manually to '
                          'accelerate the speed.')
            self.full_init()

        ori_index = self._get_ori_dataset_idx(idx)
        return self.dataset[ori_index]

    @force_full_init
    def __len__(self) -> int:
        """Return the length of data infos.

        Returns:
            int: Number of class balanced sample indices.
        """
        return len(self.sample_indices)

    def get_subset_(self, indices: Union[List[int], int]) -> None:
        """Not supported in ``CBGSDataset`` for the ambiguous meaning of sub-
        dataset."""
        raise NotImplementedError(
            '`CBGSDataset` does not support `get_subset` and '
            '`get_subset_` interfaces because this will lead to ambiguous '
            'implementation of some methods. If you want to use `get_subset` '
            'or `get_subset_` interfaces, please use them in the wrapped '
            'dataset first and then use `CBGSDataset`.')

    def get_subset(self, indices: Union[List[int], int]) -> BaseDataset:
        """Not supported in ``CBGSDataset`` for the ambiguous meaning of sub-
        dataset."""
        raise NotImplementedError(
            '`CBGSDataset` does not support `get_subset` and '
            '`get_subset_` interfaces because this will lead to ambiguous '
            'implementation of some methods. If you want to use `get_subset` '
            'or `get_subset_` interfaces, please use them in the wrapped '
            'dataset first and then use `CBGSDataset`.')