# sunrgbd_dataset.py
import os.path as osp

import mmcv
import numpy as np

from mmdet.datasets import DATASETS
from .indoor_base_dataset import IndoorBaseDataset


@DATASETS.register_module()
class SunrgbdBaseDataset(IndoorBaseDataset):
    """SUN RGB-D dataset for indoor 3D detection.

    Specialises ``IndoorBaseDataset`` with the SUN RGB-D class list, the
    ``sunrgbd_trainval`` data directory and the SUN RGB-D annotation layout.

    Args:
        root_path (str): Dataset root directory. Point clouds are read from
            ``<root_path>/sunrgbd_trainval/lidar``.
        ann_file: Forwarded unchanged to ``IndoorBaseDataset``.
        pipeline (optional): Forwarded unchanged to ``IndoorBaseDataset``.
        training (bool): Forwarded unchanged to ``IndoorBaseDataset``.
        cat_ids (optional): Forwarded unchanged to ``IndoorBaseDataset``.
        test_mode (bool): Forwarded unchanged to ``IndoorBaseDataset``.
        with_label (bool): Forwarded unchanged to ``IndoorBaseDataset``.
    """

    # Object categories, in label-index order.
    CLASSES = ('bed', 'table', 'sofa', 'chair', 'toilet', 'desk', 'dresser',
               'night_stand', 'bookshelf', 'bathtub')

    def __init__(self,
                 root_path,
                 ann_file,
                 pipeline=None,
                 training=False,
                 cat_ids=None,
                 test_mode=False,
                 with_label=True):
        super().__init__(root_path, ann_file, pipeline, training, cat_ids,
                         test_mode, with_label)
        # All per-sample files live under this subdirectory of the root.
        self.data_path = osp.join(root_path, 'sunrgbd_trainval')

    def _get_pts_filename(self, sample_idx):
        """Return the point-cloud file path for ``sample_idx``.

        The path is ``<data_path>/lidar/<sample_idx zero-padded to 6>.npy``.

        Args:
            sample_idx (int): Sample id; must support the ``06d`` format.

        Returns:
            str: Path to the ``.npy`` point-cloud file.

        Raises:
            FileNotFoundError: Raised by ``mmcv.check_file_exist`` when the
                file does not exist.
        """
        pts_filename = osp.join(self.data_path, 'lidar',
                                f'{sample_idx:06d}.npy')
        mmcv.check_file_exist(pts_filename)
        return pts_filename

    def _get_ann_info(self, index, sample_idx):
        """Build the ground-truth annotation dict for one sample.

        Args:
            index (int): Position of the sample in ``self.infos``.
            sample_idx: Not used here; kept so the signature stays
                compatible with existing callers.

        Returns:
            dict: With keys

            - ``gt_bboxes_3d``: ``annos['gt_boxes_upright_depth']``, or a
              single all-zero float32 placeholder box when the sample has
              no objects.
            - ``gt_labels``: ``annos['class']``, or a single placeholder
              label ``0`` (integer dtype) for an empty sample.
            - ``gt_bboxes_3d_mask``: boolean array, ``True`` for every real
              box and ``False`` for the placeholder.
        """
        # Use index to get the annos, thus the evalhook could also use this
        # api.
        annos = self.infos[index]['annos']

        if annos['gt_num'] != 0:
            gt_bboxes_3d = annos['gt_boxes_upright_depth']  # k, 6
            gt_labels = annos['class']
            # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
            gt_bboxes_3d_mask = np.ones_like(gt_labels, dtype=bool)
        else:
            # Empty scene: emit one dummy box whose mask is False so the
            # consumer can ignore it.
            # NOTE(review): placeholder width 6 should match the width of
            # ``gt_boxes_upright_depth`` in the info files -- confirm.
            gt_bboxes_3d = np.zeros((1, 6), dtype=np.float32)
            # Labels are class ids, so keep an integer dtype. The previous
            # bool dtype would act as a boolean mask if used for indexing.
            gt_labels = np.zeros(1, dtype=np.int64)
            gt_bboxes_3d_mask = np.zeros(1, dtype=bool)

        anns_results = dict(
            gt_bboxes_3d=gt_bboxes_3d,
            gt_labels=gt_labels,
            gt_bboxes_3d_mask=gt_bboxes_3d_mask)
        return anns_results