Unverified Commit f4ae6d75 authored by Xiang Xu's avatar Xiang Xu Committed by GitHub
Browse files

[Feature] Support lazy_init for `CBGSDataset` (#2227)

* refactor cbgs

* update cbgs

* add UT

* remove useless line

* Update test_dataset_wrappers.py

* update typehint

* update docs

* Update dataset_wrappers.py

* update
parent 50870e4c
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings
from typing import List, Set, Union
import numpy as np
from mmengine.dataset import BaseDataset, force_full_init
from mmdet3d.registry import DATASETS
@DATASETS.register_module()
class CBGSDataset:
    """A wrapper of class sampled dataset with ann_file path. Implementation
    of paper `Class-balanced Grouping and Sampling for Point Cloud 3D Object
    Detection <https://arxiv.org/abs/1908.09492>`_.

    Balance the number of scenes under different classes.

    Args:
        dataset (:obj:`BaseDataset` or dict): The dataset to be class
            sampled.
        lazy_init (bool): Whether to load annotation during instantiation.
            Defaults to False.
    """

    def __init__(self,
                 dataset: Union[BaseDataset, dict],
                 lazy_init: bool = False) -> None:
        self.dataset: BaseDataset
        # Accept either an already-built dataset or a config dict to build.
        if isinstance(dataset, dict):
            self.dataset = DATASETS.build(dataset)
        elif isinstance(dataset, BaseDataset):
            self.dataset = dataset
        else:
            raise TypeError(
                'elements in datasets sequence should be config or '
                f'`BaseDataset` instance, but got {type(dataset)}')
        self._metainfo = self.dataset.metainfo
        # Defer the expensive index sampling until `full_init` is called
        # (either explicitly or on first access) when `lazy_init=True`.
        self._fully_initialized = False
        if not lazy_init:
            self.full_init()
@property
def metainfo(self) -> dict:
"""Get the meta information of the repeated dataset.
Returns:
dict: The meta information of repeated dataset.
"""
return copy.deepcopy(self._metainfo)
def full_init(self) -> None:
"""Loop to ``full_init`` each dataset."""
if self._fully_initialized:
return
self.dataset.full_init()
# Get sample_indices
self.sample_indices = self._get_sample_indices(self.dataset)
self._fully_initialized = True
def _get_sample_indices(self, dataset: BaseDataset) -> List[int]:
"""Load sample indices according to ann_file.
Args:
ann_file (str): Path of the annotation file.
dataset (:obj:`BaseDataset`): The dataset.
Returns:
list[dict]: List of annotations after class sampling.
List[dict]: List of indices after class sampling.
"""
class_sample_idxs = {cat_id: [] for cat_id in self.cat2id.values()}
for idx in range(len(self.dataset)):
sample_cat_ids = self.dataset.get_cat_ids(idx)
classes = self.metainfo['classes']
cat2id = {name: i for i, name in enumerate(classes)}
class_sample_idxs = {cat_id: [] for cat_id in cat2id.values()}
for idx in range(len(dataset)):
sample_cat_ids = dataset.get_cat_ids(idx)
for cat_id in sample_cat_ids:
if cat_id != -1:
# Filter categories that do not need to care.
# -1 indicate dontcare in MMDet3d.
# Filter categories that do not need to be cared.
# -1 indicates dontcare in MMDet3D.
class_sample_idxs[cat_id].append(idx)
duplicated_samples = sum(
[len(v) for _, v in class_sample_idxs.items()])
......@@ -54,7 +89,7 @@ class CBGSDataset(object):
sample_indices = []
frac = 1.0 / len(self.classes)
frac = 1.0 / len(classes)
ratios = [frac / v for v in class_distribution.values()]
for cls_inds, ratio in zip(list(class_sample_idxs.values()), ratios):
sample_indices += np.random.choice(cls_inds,
......@@ -62,19 +97,86 @@ class CBGSDataset(object):
ratio)).tolist()
return sample_indices
def __getitem__(self, idx):
@force_full_init
def _get_ori_dataset_idx(self, idx: int) -> int:
"""Convert global index to local index.
Args:
idx (int): Global index of ``CBGSDataset``.
Returns:
int: Local index of data.
"""
return self.sample_indices[idx]
@force_full_init
def get_cat_ids(self, idx: int) -> Set[int]:
"""Get category ids of class balanced dataset by index.
Args:
idx (int): Index of data.
Returns:
Set[int]: All categories in the sample of specified index.
"""
sample_idx = self._get_ori_dataset_idx(idx)
return self.dataset.get_cat_ids(sample_idx)
@force_full_init
def get_data_info(self, idx: int) -> dict:
"""Get annotation by index.
Args:
idx (int): Global index of ``CBGSDataset``.
Returns:
dict: The idx-th annotation of the dataset.
"""
sample_idx = self._get_ori_dataset_idx(idx)
return self.dataset.get_data_info(sample_idx)
def __getitem__(self, idx: int) -> dict:
"""Get item from infos according to the given index.
Args:
idx (int): The index of self.sample_indices.
Returns:
dict: Data dictionary of the corresponding index.
"""
ori_idx = self.sample_indices[idx]
return self.dataset[ori_idx]
if not self._fully_initialized:
warnings.warn('Please call `full_init` method manually to '
'accelerate the speed.')
self.full_init()
ori_index = self._get_ori_dataset_idx(idx)
return self.dataset[ori_index]
def __len__(self):
@force_full_init
def __len__(self) -> int:
"""Return the length of data infos.
Returns:
int: Length of data infos.
"""
return len(self.sample_indices)
def get_subset_(self, indices: Union[List[int], int]) -> None:
"""Not supported in ``CBGSDataset`` for the ambiguous meaning of sub-
dataset."""
raise NotImplementedError(
'`CBGSDataset` does not support `get_subset` and '
'`get_subset_` interfaces because this will lead to ambiguous '
'implementation of some methods. If you want to use `get_subset` '
'or `get_subset_` interfaces, please use them in the wrapped '
'dataset first and then use `CBGSDataset`.')
def get_subset(self, indices: Union[List[int], int]) -> BaseDataset:
"""Not supported in ``CBGSDataset`` for the ambiguous meaning of sub-
dataset."""
raise NotImplementedError(
'`CBGSDataset` does not support `get_subset` and '
'`get_subset_` interfaces because this will lead to ambiguous '
'implementation of some methods. If you want to use `get_subset` '
'or `get_subset_` interfaces, please use them in the wrapped '
'dataset first and then use `CBGSDataset`.')
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import numpy as np
import pytest
from mmcv.transforms.base import BaseTransform
from mmengine.structures import InstanceData
from mmdet3d.datasets import CBGSDataset, NuScenesDataset
from mmdet3d.registry import DATASETS, TRANSFORMS
from mmdet3d.structures import Det3DDataSample
def is_equal(dict_a, dict_b):
    """Recursively compare two dicts for equality.

    Supports nested dicts and numpy arrays as values.

    Args:
        dict_a (dict): First dict.
        dict_b (dict): Second dict.

    Returns:
        bool: True if both dicts hold the same keys and equal values.
    """
    # Reject extra keys in either dict (same key set required).
    if len(dict_a) != len(dict_b):
        return False
    for key in dict_a:
        if key not in dict_b:
            return False
        if isinstance(dict_a[key], dict):
            # Bug fix: recurse and continue checking the remaining keys
            # instead of returning the sub-result immediately.
            if not is_equal(dict_a[key], dict_b[key]):
                return False
        elif isinstance(dict_a[key], np.ndarray):
            # Bug fix: arrays are equal only if ALL elements match
            # (`.any()` accepted a single matching element); also guard
            # against shape mismatch before element-wise comparison.
            if np.shape(dict_a[key]) != np.shape(dict_b[key]):
                return False
            if not (dict_a[key] == dict_b[key]).all():
                return False
        else:
            if not (dict_a[key] == dict_b[key]):
                return False
    return True
@TRANSFORMS.register_module()
class Identity(BaseTransform):
    """Minimal test transform that packs ground-truth labels (if present)
    into a :obj:`Det3DDataSample`."""

    def transform(self, info):
        """Pack ``info`` into the model-input format used by the tests."""
        data_sample = Det3DDataSample()
        if 'ann_info' in info:
            instances = InstanceData()
            instances.labels_3d = info['ann_info']['gt_labels_3d']
            data_sample.gt_instances_3d = instances
        return dict(data_samples=data_sample)
@DATASETS.register_module()
class CustomDataset(NuScenesDataset):
    """Alias of :class:`NuScenesDataset` registered under a second name so
    tests can build it from a config dict."""
class TestCBGSDataset:
    """Unit tests for ``CBGSDataset``.

    The fixture wraps a small ``NuScenesDataset`` loaded from test data
    (``../data/nuscenes/nus_info.pkl`` relative to this file) and overrides
    ``sample_indices`` with a fixed list so results are deterministic.
    """

    def setup(self):
        # NOTE(review): nose-style `setup` is deprecated in recent pytest;
        # consider renaming to `setup_method`.
        dataset = NuScenesDataset
        self.dataset = dataset(
            data_root=osp.join(osp.dirname(__file__), '../data/nuscenes'),
            ann_file='nus_info.pkl',
            data_prefix=dict(
                pts='samples/LIDAR_TOP', img='', sweeps='sweeps/LIDAR_TOP'),
            pipeline=[dict(type=Identity)])
        # Fixed balanced indices: sample 0 twice and sample 1 three times.
        self.sample_indices = [0, 0, 1, 1, 1]
        # test init
        self.cbgs_datasets = CBGSDataset(dataset=self.dataset)
        # Override the computed indices with the deterministic fixture.
        self.cbgs_datasets.sample_indices = self.sample_indices

    def test_init(self):
        # Test build dataset from cfg
        dataset_cfg = dict(
            type=CustomDataset,
            data_root=osp.join(osp.dirname(__file__), '../data/nuscenes'),
            ann_file='nus_info.pkl',
            data_prefix=dict(
                pts='samples/LIDAR_TOP', img='', sweeps='sweeps/LIDAR_TOP'),
            pipeline=[dict(type=Identity)])
        cbgs_datasets = CBGSDataset(dataset=dataset_cfg)
        cbgs_datasets.sample_indices = self.sample_indices
        # Share the pipeline instance so both wrappers produce identical
        # outputs for comparison.
        cbgs_datasets.dataset.pipeline = self.dataset.pipeline
        assert len(cbgs_datasets) == len(self.cbgs_datasets)
        for i in range(len(cbgs_datasets)):
            assert is_equal(
                cbgs_datasets.get_data_info(i),
                self.cbgs_datasets.get_data_info(i))
            assert (cbgs_datasets[i]['data_samples'].gt_instances_3d.labels_3d
                    == self.cbgs_datasets[i]
                    ['data_samples'].gt_instances_3d.labels_3d).any()
        # A non-dict, non-BaseDataset argument must be rejected.
        with pytest.raises(TypeError):
            CBGSDataset(dataset=[0])

    def test_full_init(self):
        self.cbgs_datasets.full_init()
        self.cbgs_datasets.sample_indices = self.sample_indices
        assert len(self.cbgs_datasets) == len(self.sample_indices)
        # Reinit `sample_indices`
        # Resetting the flag makes the next (decorated) `__len__` call
        # re-run `full_init`, which recomputes `sample_indices` and thus
        # discards the fixture override assigned just below.
        self.cbgs_datasets._fully_initialized = False
        self.cbgs_datasets.sample_indices = self.sample_indices
        assert len(self.cbgs_datasets) != len(self.sample_indices)

        # `get_subset`/`get_subset_` are explicitly unsupported.
        with pytest.raises(NotImplementedError):
            self.cbgs_datasets.get_subset_(1)
        with pytest.raises(NotImplementedError):
            self.cbgs_datasets.get_subset(1)

    def test_metainfo(self):
        # The wrapper must expose the wrapped dataset's metainfo unchanged.
        assert self.cbgs_datasets.metainfo == self.dataset.metainfo

    def test_length(self):
        assert len(self.cbgs_datasets) == len(self.sample_indices)

    def test_getitem(self):
        # Each balanced index must yield the same labels as indexing the
        # wrapped dataset directly at the mapped position.
        for i in range(len(self.sample_indices)):
            assert (self.cbgs_datasets[i]['data_samples'].gt_instances_3d.
                    labels_3d == self.dataset[self.sample_indices[i]]
                    ['data_samples'].gt_instances_3d.labels_3d).any()

    def test_get_data_info(self):
        for i in range(len(self.sample_indices)):
            assert is_equal(
                self.cbgs_datasets.get_data_info(i),
                self.dataset.get_data_info(self.sample_indices[i]))

    def test_get_cat_ids(self):
        for i in range(len(self.sample_indices)):
            assert self.cbgs_datasets.get_cat_ids(
                i) == self.dataset.get_cat_ids(self.sample_indices[i])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment