Commit 2a7030a5 authored by VVsssssk's avatar VVsssssk Committed by ChaimZhu
Browse files

[Refactor]Support classes balance dataset

parent c66197c7
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS, PIPELINES, build_dataset from .builder import DATASETS, PIPELINES, build_dataset
from .dataset_wrappers import CBGSDataset
from .det3d_dataset import Det3DDataset from .det3d_dataset import Det3DDataset
from .kitti_dataset import KittiDataset from .kitti_dataset import KittiDataset
from .kitti_mono_dataset import KittiMonoDataset from .kitti_mono_dataset import KittiMonoDataset
...@@ -28,17 +29,17 @@ from .utils import get_loading_pipeline ...@@ -28,17 +29,17 @@ from .utils import get_loading_pipeline
from .waymo_dataset import WaymoDataset from .waymo_dataset import WaymoDataset
__all__ = [ __all__ = [
'KittiDataset', 'KittiMonoDataset', 'DATASETS', 'build_dataset', 'KittiDataset', 'KittiMonoDataset', 'DATASETS', 'CBGSDataset',
'NuScenesDataset', 'NuScenesMonoDataset', 'LyftDataset', 'ObjectSample', 'build_dataset', 'NuScenesDataset', 'NuScenesMonoDataset', 'LyftDataset',
'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans', 'PointShuffle', 'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans',
'ObjectRangeFilter', 'PointsRangeFilter', 'LoadPointsFromFile', 'PointShuffle', 'ObjectRangeFilter', 'PointsRangeFilter',
'S3DISSegDataset', 'S3DISDataset', 'NormalizePointsColor', 'LoadPointsFromFile', 'S3DISSegDataset', 'S3DISDataset',
'IndoorPatchPointSample', 'IndoorPointSample', 'PointSample', 'NormalizePointsColor', 'IndoorPatchPointSample', 'IndoorPointSample',
'LoadAnnotations3D', 'GlobalAlignment', 'SUNRGBDDataset', 'ScanNetDataset', 'PointSample', 'LoadAnnotations3D', 'GlobalAlignment', 'SUNRGBDDataset',
'ScanNetSegDataset', 'ScanNetInstanceSegDataset', 'SemanticKITTIDataset', 'ScanNetDataset', 'ScanNetSegDataset', 'ScanNetInstanceSegDataset',
'Det3DDataset', 'Seg3DDataset', 'LoadPointsFromMultiSweeps', 'SemanticKITTIDataset', 'Det3DDataset', 'Seg3DDataset',
'WaymoDataset', 'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'LoadPointsFromMultiSweeps', 'WaymoDataset', 'BackgroundPointsFilter',
'get_loading_pipeline', 'RandomDropPointsColor', 'RandomJitterPoints', 'VoxelBasedPointSampler', 'get_loading_pipeline', 'RandomDropPointsColor',
'ObjectNameFilter', 'AffineResize', 'RandomShiftScale', 'RandomJitterPoints', 'ObjectNameFilter', 'AffineResize',
'LoadPointsFromDict', 'PIPELINES' 'RandomShiftScale', 'LoadPointsFromDict', 'PIPELINES'
] ]
...@@ -17,8 +17,8 @@ class CBGSDataset(object): ...@@ -17,8 +17,8 @@ class CBGSDataset(object):
""" """
def __init__(self, dataset): def __init__(self, dataset):
self.dataset = dataset self.dataset = DATASETS.build(dataset)
self.CLASSES = dataset.CLASSES self.CLASSES = self.dataset.metainfo['CLASSES']
self.cat2id = {name: i for i, name in enumerate(self.CLASSES)} self.cat2id = {name: i for i, name in enumerate(self.CLASSES)}
self.sample_indices = self._get_sample_indices() self.sample_indices = self._get_sample_indices()
# self.dataset.data_infos = self.data_infos # self.dataset.data_infos = self.data_infos
...@@ -40,6 +40,9 @@ class CBGSDataset(object): ...@@ -40,6 +40,9 @@ class CBGSDataset(object):
for idx in range(len(self.dataset)): for idx in range(len(self.dataset)):
sample_cat_ids = self.dataset.get_cat_ids(idx) sample_cat_ids = self.dataset.get_cat_ids(idx)
for cat_id in sample_cat_ids: for cat_id in sample_cat_ids:
if cat_id != -1:
# Filter categories that do not need to care.
# -1 indicate dontcare in MMDet3d.
class_sample_idxs[cat_id].append(idx) class_sample_idxs[cat_id].append(idx)
duplicated_samples = sum( duplicated_samples = sum(
[len(v) for _, v in class_sample_idxs.items()]) [len(v) for _, v in class_sample_idxs.items()])
......
...@@ -294,3 +294,20 @@ class Det3DDataset(BaseDataset): ...@@ -294,3 +294,20 @@ class Det3DDataset(BaseDataset):
example['data_sample'].gt_instances_3d.labels_3d) == 0: example['data_sample'].gt_instances_3d.labels_3d) == 0:
return None return None
return example return example
def get_cat_ids(self, idx: int) -> List[int]:
"""Get category ids by index. Dataset wrapped by ClassBalancedDataset
must implement this method.
The ``CBGSDataset`` or ``ClassBalancedDataset``requires a subclass
which implements this method.
Args:
idx (int): The index of data.
Returns:
set[int]: All categories in the sample of specified index.
"""
info = self.get_data_info(idx)
gt_labels = info['ann_info']['gt_labels_3d'].tolist()
return set(gt_labels)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment