scannet_dataset.py 2.61 KB
Newer Older
1
2
3
4
5
6
import os.path as osp

import mmcv
import numpy as np

from mmdet.datasets import DATASETS
7
from .indoor_dataset import IndoorDataset
8
9
10


@DATASETS.register_module()
class ScannetDataset(IndoorDataset):
    """ScanNet indoor 3D detection dataset.

    Per-sample files are looked up under
    ``<root_path>/scannet_train_instance_data`` with the naming scheme
    ``{sample_idx}_vert.npy`` (points), ``{sample_idx}_ins_label.npy``
    (instance mask) and ``{sample_idx}_sem_label.npy`` (semantic mask).

    Args:
        root_path (str): Root directory of the dataset.
        ann_file (str): Path of the annotation/info file; forwarded to
            ``IndoorDataset``.
        pipeline: Forwarded unchanged to ``IndoorDataset``.
        training (bool): Forwarded unchanged to ``IndoorDataset``.
        class_names: Forwarded unchanged to ``IndoorDataset``.
        test_mode (bool): Forwarded unchanged to ``IndoorDataset``.
        with_label (bool): Forwarded unchanged to ``IndoorDataset``.
    """

    # NOTE(review): 'showercurtrain' is kept verbatim. These strings are
    # runtime class names, so correcting the spelling could break
    # name-based lookups; confirm against the info files before changing.
    CLASSES = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
               'bookshelf', 'picture', 'counter', 'desk', 'curtain',
               'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub',
               'garbagebin')
    # Label index -> class name. This is derived from CLASSES so the two
    # cannot drift apart: {0: 'cabinet', 1: 'bed', ..., 17: 'garbagebin'}.
    class2type = dict(enumerate(CLASSES))

    def __init__(self,
                 root_path,
                 ann_file,
                 pipeline=None,
                 training=False,
                 class_names=None,
                 test_mode=False,
                 with_label=True):
        super().__init__(root_path, ann_file, pipeline, training, class_names,
                         test_mode, with_label)

        # Directory holding the per-sample point / label .npy files.
        self.data_path = osp.join(root_path, 'scannet_train_instance_data')

    def _get_pts_filename(self, sample_idx):
        """Return the point-cloud file path for ``sample_idx``.

        Args:
            sample_idx: Sample identifier used in the file name.

        Returns:
            str: ``<data_path>/{sample_idx}_vert.npy``.

        Raises:
            An error from ``mmcv.check_file_exist`` if the file is missing.
        """
        pts_filename = osp.join(self.data_path, f'{sample_idx}_vert.npy')
        mmcv.check_file_exist(pts_filename)
        return pts_filename

    def _get_ann_info(self, index, sample_idx):
        """Build the ground-truth annotation dict for one sample.

        Args:
            index (int): Position in ``self.infos``. ``self.infos`` is
                presumably populated by ``IndoorDataset`` from ``ann_file``;
                TODO confirm.
            sample_idx: Sample identifier used to build the mask file paths.

        Returns:
            dict: Contains ``gt_bboxes_3d``, ``gt_labels``,
            ``gt_bboxes_3d_mask``, ``pts_instance_mask_path`` and
            ``pts_semantic_mask_path``. When the sample has no objects,
            a single all-zero placeholder box and label are returned, and
            the mask marks the placeholder as invalid.
        """
        # Use index to get the annos, thus the evalhook could also use this api
        annos = self.infos[index]['annos']

        if annos['gt_num'] != 0:
            gt_bboxes_3d = annos['gt_boxes_upright_depth']  # k, 6
            gt_labels = annos['class']
            # Every real box is valid. The builtin bool replaces np.bool,
            # which was removed in NumPy 1.24.
            gt_bboxes_3d_mask = np.ones_like(gt_labels, dtype=bool)
        else:
            # No objects: emit one dummy box/label so downstream shapes are
            # non-empty, and mask it out.
            gt_bboxes_3d = np.zeros((1, 6), dtype=np.float32)
            # Integer placeholder, matching the dtype role of class indices
            # (previously a bool array).
            gt_labels = np.zeros(1, dtype=np.int64)
            gt_bboxes_3d_mask = np.zeros(1, dtype=bool)

        pts_instance_mask_path = osp.join(self.data_path,
                                          f'{sample_idx}_ins_label.npy')
        pts_semantic_mask_path = osp.join(self.data_path,
                                          f'{sample_idx}_sem_label.npy')

        anns_results = dict(
            gt_bboxes_3d=gt_bboxes_3d,
            gt_labels=gt_labels,
            gt_bboxes_3d_mask=gt_bboxes_3d_mask,
            pts_instance_mask_path=pts_instance_mask_path,
            pts_semantic_mask_path=pts_semantic_mask_path)
        return anns_results