Commit 729f65c9 authored by liyinhao

Rename dataset classes and parameters; move indoor_eval out of indoor_utils into its own module

parent 43243598
 from .class_names import dataset_aliases, get_classes, kitti_classes
-from .indoor_utils import indoor_eval
+from .indoor_eval import indoor_eval
 from .kitti_utils import kitti_eval, kitti_eval_coco_style
 __all__ = [
......
+from .eval import indoor_eval
+__all__ = ['indoor_eval']
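For context, a minimal usage sketch of the relocated entry point. The package path mmdet3d.core.evaluation is an assumption (it is not shown in this diff); the only facts taken from the commit are the function name and the four-argument call used in IndoorBaseDataset.evaluate below.

# Hypothetical usage sketch -- import path assumed, not part of this diff.
from mmdet3d.core.evaluation import indoor_eval  # formerly exposed via indoor_utils

# gt_annos, results and metric keep whatever format indoor_eval already expected;
# the call matches the signature used in IndoorBaseDataset.evaluate in this commit:
#     ap_result_str, ap_dict = indoor_eval(gt_annos, results, metric, label2cat)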
 from mmdet.datasets.builder import DATASETS
 from .builder import build_dataset
 from .dataset_wrappers import RepeatFactorDataset
-from .indoor_dataset import IndoorDataset
+from .indoor_base_dataset import IndoorBaseDataset
 from .kitti2d_dataset import Kitti2DDataset
 from .kitti_dataset import KittiDataset
 from .loader import DistributedGroupSampler, GroupSampler, build_dataloader
@@ -12,8 +12,8 @@ from .pipelines import (GlobalRotScale, IndoorFlipData, IndoorGlobalRotScale,
                         IndoorPointsColorNormalize, ObjectNoise,
                         ObjectRangeFilter, ObjectSample, PointShuffle,
                         PointsRangeFilter, RandomFlip3D)
-from .scannet_dataset import ScannetDataset
-from .sunrgbd_dataset import SunrgbdDataset
+from .scannet_dataset import ScannetBaseDataset
+from .sunrgbd_dataset import SunrgbdBaseDataset
 __all__ = [
     'KittiDataset', 'GroupSampler', 'DistributedGroupSampler',
@@ -23,6 +23,6 @@ __all__ = [
     'ObjectRangeFilter', 'PointsRangeFilter', 'Collect3D',
     'IndoorLoadPointsFromFile', 'IndoorPointsColorNormalize',
     'IndoorPointSample', 'IndoorLoadAnnotations3D', 'IndoorPointsColorJitter',
-    'IndoorGlobalRotScale', 'IndoorFlipData', 'SunrgbdDataset',
-    'ScannetDataset', 'IndoorDataset'
+    'IndoorGlobalRotScale', 'IndoorFlipData', 'SunrgbdBaseDataset',
+    'ScannetBaseDataset', 'IndoorBaseDataset'
 ]
@@ -9,28 +9,28 @@ from .pipelines import Compose
 @DATASETS.register_module()
-class IndoorDataset(torch_data.Dataset):
+class IndoorBaseDataset(torch_data.Dataset):
     def __init__(self,
                  root_path,
                  ann_file,
                  pipeline=None,
                  training=False,
-                 class_names=None,
+                 cat_ids=None,
                  test_mode=False,
                  with_label=True):
         super().__init__()
         self.root_path = root_path
-        self.class_names = class_names if class_names else self.CLASSES
+        self.cat_ids = cat_ids if cat_ids else self.CLASSES
         self.test_mode = test_mode
         self.training = training
         self.mode = 'TRAIN' if self.training else 'TEST'
+        self.label2cat = {i: cat_id for i, cat_id in enumerate(self.cat_ids)}
         mmcv.check_file_exist(ann_file)
         self.infos = mmcv.load(ann_file)
         # dataset config
-        self.num_class = len(self.class_names)
+        self.num_class = len(self.cat_ids)
         self.pcd_limit_range = [0, -40, -3.0, 70.4, 40, 3.0]
         if pipeline is not None:
             self.pipeline = Compose(pipeline)
@@ -134,7 +134,7 @@ class IndoorDataset(torch_data.Dataset):
         assert len(metric) > 0
         gt_annos = [copy.deepcopy(info['annos']) for info in self.infos]
         ap_result_str, ap_dict = indoor_eval(gt_annos, results, metric,
-                                             self.class2type)
+                                             self.label2cat)
         return ap_dict
     def __len__(self):
......
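The hard-coded class2type dicts deleted in the next two files are replaced by the label2cat mapping built once in the base class above. A small self-contained sketch of what that comprehension produces for the ScanNet classes (the CLASSES tuple is taken from this diff; 'garbagebin' as the final entry is inferred from the deleted dict):

# Sketch: label2cat derived from CLASSES, as in IndoorBaseDataset.__init__ above.
CLASSES = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
           'bookshelf', 'picture', 'counter', 'desk', 'curtain',
           'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub',
           'garbagebin')
label2cat = {i: cat_id for i, cat_id in enumerate(CLASSES)}
# -> {0: 'cabinet', 1: 'bed', 2: 'chair', ..., 17: 'garbagebin'}
# i.e. the same label-index-to-name mapping the removed class2type dicts spelled out by hand.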
@@ -4,31 +4,12 @@ import mmcv
 import numpy as np
 from mmdet.datasets import DATASETS
-from .indoor_dataset import IndoorDataset
+from .indoor_base_dataset import IndoorBaseDataset
 @DATASETS.register_module()
-class ScannetDataset(IndoorDataset):
-    class2type = {
-        0: 'cabinet',
-        1: 'bed',
-        2: 'chair',
-        3: 'sofa',
-        4: 'table',
-        5: 'door',
-        6: 'window',
-        7: 'bookshelf',
-        8: 'picture',
-        9: 'counter',
-        10: 'desk',
-        11: 'curtain',
-        12: 'refrigerator',
-        13: 'showercurtrain',
-        14: 'toilet',
-        15: 'sink',
-        16: 'bathtub',
-        17: 'garbagebin'
-    }
+class ScannetBaseDataset(IndoorBaseDataset):
     CLASSES = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
                'bookshelf', 'picture', 'counter', 'desk', 'curtain',
                'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub',
@@ -39,10 +20,10 @@ class ScannetDataset(IndoorDataset):
                  ann_file,
                  pipeline=None,
                  training=False,
-                 class_names=None,
+                 cat_ids=None,
                  test_mode=False,
                  with_label=True):
-        super().__init__(root_path, ann_file, pipeline, training, class_names,
+        super().__init__(root_path, ann_file, pipeline, training, cat_ids,
                          test_mode, with_label)
         self.data_path = osp.join(root_path, 'scannet_train_instance_data')
......
@@ -4,24 +4,12 @@ import mmcv
 import numpy as np
 from mmdet.datasets import DATASETS
-from .indoor_dataset import IndoorDataset
+from .indoor_base_dataset import IndoorBaseDataset
 @DATASETS.register_module()
-class SunrgbdDataset(IndoorDataset):
-    class2type = {
-        0: 'bed',
-        1: 'table',
-        2: 'sofa',
-        3: 'chair',
-        4: 'toilet',
-        5: 'desk',
-        6: 'dresser',
-        7: 'night_stand',
-        8: 'bookshelf',
-        9: 'bathtub'
-    }
+class SunrgbdBaseDataset(IndoorBaseDataset):
     CLASSES = ('bed', 'table', 'sofa', 'chair', 'toilet', 'desk', 'dresser',
                'night_stand', 'bookshelf', 'bathtub')
@@ -30,10 +18,10 @@ class SunrgbdDataset(IndoorDataset):
                  ann_file,
                  pipeline=None,
                  training=False,
-                 class_names=None,
+                 cat_ids=None,
                  test_mode=False,
                  with_label=True):
-        super().__init__(root_path, ann_file, pipeline, training, class_names,
+        super().__init__(root_path, ann_file, pipeline, training, cat_ids,
                          test_mode, with_label)
         self.data_path = osp.join(root_path, 'sunrgbd_trainval')
......
@@ -2,7 +2,7 @@ import numpy as np
 import pytest
 import torch
-from mmdet3d.datasets import ScannetDataset
+from mmdet3d.datasets import ScannetBaseDataset
 def test_getitem():
@@ -36,7 +36,7 @@ def test_getitem():
         ]),
     ]
-    scannet_dataset = ScannetDataset(root_path, ann_file, pipelines, True)
+    scannet_dataset = ScannetBaseDataset(root_path, ann_file, pipelines, True)
     data = scannet_dataset[0]
     points = data['points']._data
     gt_bboxes_3d = data['gt_bboxes_3d']._data
@@ -77,7 +77,7 @@ def test_evaluate():
         pytest.skip()
     root_path = './tests/data/scannet'
     ann_file = './tests/data/scannet/scannet_infos.pkl'
-    scannet_dataset = ScannetDataset(root_path, ann_file)
+    scannet_dataset = ScannetBaseDataset(root_path, ann_file)
     results = []
     pred_boxes = dict()
     pred_boxes['box3d_lidar'] = np.array([[
......
@@ -2,7 +2,7 @@ import numpy as np
 import pytest
 import torch
-from mmdet3d.datasets import SunrgbdDataset
+from mmdet3d.datasets import SunrgbdBaseDataset
 def test_getitem():
@@ -28,7 +28,7 @@ def test_getitem():
         dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels']),
     ]
-    sunrgbd_dataset = SunrgbdDataset(root_path, ann_file, pipelines, True)
+    sunrgbd_dataset = SunrgbdBaseDataset(root_path, ann_file, pipelines, True)
     data = sunrgbd_dataset[0]
     points = data['points']._data
     gt_bboxes_3d = data['gt_bboxes_3d']._data
@@ -67,7 +67,7 @@ def test_evaluate():
         pytest.skip()
     root_path = './tests/data/sunrgbd'
    ann_file = './tests/data/sunrgbd/sunrgbd_infos.pkl'
-    sunrgbd_dataset = SunrgbdDataset(root_path, ann_file)
+    sunrgbd_dataset = SunrgbdBaseDataset(root_path, ann_file)
     results = []
     pred_boxes = dict()
     pred_boxes['box3d_lidar'] = np.array([[
......
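For reference, a minimal construction sketch using the renamed classes and the new cat_ids keyword. The keyword names come from the __init__ signature in this commit; the paths mirror the test fixtures above and are placeholders, not part of the change.

# Hypothetical construction sketch -- paths follow the test fixtures above.
from mmdet3d.datasets import ScannetBaseDataset, SunrgbdBaseDataset

scannet = ScannetBaseDataset(
    root_path='./tests/data/scannet',
    ann_file='./tests/data/scannet/scannet_infos.pkl',
    pipeline=None,    # or a list of pipeline dicts, as in test_getitem above
    training=False,
    cat_ids=None)     # None -> fall back to the class-level CLASSES tuple

sunrgbd = SunrgbdBaseDataset(
    root_path='./tests/data/sunrgbd',
    ann_file='./tests/data/sunrgbd/sunrgbd_infos.pkl')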