scannet_dataset.py 1.87 KB
Newer Older
1
2
3
4
5
import os.path as osp

import numpy as np

from mmdet.datasets import DATASETS
zhangwenwei's avatar
zhangwenwei committed
6
from .custom_3d import Custom3DDataset
7
8
9


@DATASETS.register_module()
class ScanNetDataset(Custom3DDataset):
    """ScanNet 3D detection dataset.

    Thin specialisation of ``Custom3DDataset`` that fixes the ScanNet class
    list and resolves per-scene ``.npy`` files (points, instance labels,
    semantic labels) under ``data_root``.

    Args:
        data_root: Directory containing the per-scene ``.npy`` files; it is
            joined with ``<sample_idx>_vert.npy``, ``<sample_idx>_ins_label.npy``
            and ``<sample_idx>_sem_label.npy`` to build file paths.
        ann_file: Annotation file, forwarded unchanged to ``Custom3DDataset``.
        pipeline: Data pipeline, forwarded to the base class. Default: None.
        classes: Class names, forwarded to the base class (presumably
            overrides ``CLASSES`` when given — confirm in
            ``Custom3DDataset``). Default: None.
        test_mode: Forwarded to the base class. Default: False.
    """

    CLASSES = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
               'bookshelf', 'picture', 'counter', 'desk', 'curtain',
               'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub',
               'garbagebin')

    def __init__(self,
                 data_root,
                 ann_file,
                 pipeline=None,
                 classes=None,
                 test_mode=False):
        super().__init__(data_root, ann_file, pipeline, classes, test_mode)

    def _get_pts_filename(self, sample_idx):
        """Return the path of the point-cloud file for ``sample_idx``.

        Args:
            sample_idx: Scene identifier used as the file-name prefix.

        Returns:
            str: ``<data_root>/<sample_idx>_vert.npy``.
        """
        pts_filename = osp.join(self.data_root, f'{sample_idx}_vert.npy')
        return pts_filename

    def get_ann_info(self, index):
        """Build the annotation dict for the sample at ``index``.

        Annotations are looked up by index (not by sample id) so that
        evaluation hooks can call this API directly.

        Args:
            index: Position in ``self.data_infos``.

        Returns:
            dict: With keys
                - ``gt_bboxes_3d``: float32 array taken from
                  ``annos['gt_boxes_upright_depth']``, expected shape (k, 6);
                  an empty (0, 6) array when ``annos['gt_num'] == 0``.
                - ``gt_labels_3d``: int64 array taken from
                  ``annos['class']``; empty (0,) when there are no boxes.
                - ``pts_instance_mask_path``: path to
                  ``<sample_idx>_ins_label.npy`` under ``data_root``.
                - ``pts_semantic_mask_path``: path to
                  ``<sample_idx>_sem_label.npy`` under ``data_root``.
        """
        info = self.data_infos[index]
        annos = info['annos']
        if annos['gt_num'] != 0:
            # Assumed (k, 6) box layout per the original comment — the exact
            # meaning of the 6 columns is defined by the data converter.
            gt_bboxes_3d = annos['gt_boxes_upright_depth'].astype(np.float32)
            # np.int64 instead of the removed ``np.long`` alias (NumPy >= 1.24
            # raises AttributeError; NumPy 2 maps it to a 32-bit C long on
            # Windows).
            gt_labels_3d = annos['class'].astype(np.int64)
        else:
            gt_bboxes_3d = np.zeros((0, 6), dtype=np.float32)
            gt_labels_3d = np.zeros((0, ), dtype=np.int64)

        sample_idx = info['point_cloud']['lidar_idx']
        pts_instance_mask_path = osp.join(self.data_root,
                                          f'{sample_idx}_ins_label.npy')
        pts_semantic_mask_path = osp.join(self.data_root,
                                          f'{sample_idx}_sem_label.npy')

        anns_results = dict(
            gt_bboxes_3d=gt_bboxes_3d,
            gt_labels_3d=gt_labels_3d,
            pts_instance_mask_path=pts_instance_mask_path,
            pts_semantic_mask_path=pts_semantic_mask_path)
        return anns_results