sunrgbd_dataset.py 1.29 KB
Newer Older
liyinhao's avatar
liyinhao committed
1
2
3
4
5
import os.path as osp

import numpy as np

from mmdet.datasets import DATASETS
zhangwenwei's avatar
zhangwenwei committed
6
from .custom_3d import Custom3DDataset
liyinhao's avatar
liyinhao committed
7
8
9


@DATASETS.register_module()
class SUNRGBDDataset(Custom3DDataset):
    """SUN RGB-D dataset for 3D detection.

    Point clouds are read from ``<data_root>/lidar/<sample_idx:06d>.npy``.
    Ground-truth annotations are taken from the ``'annos'`` entry of each
    element of ``self.data_infos``.

    Args:
        data_root (str): Root directory of the dataset.
        ann_file (str): Annotation file; forwarded unchanged to
            ``Custom3DDataset``.
        pipeline (optional): Data pipeline config, forwarded unchanged to
            ``Custom3DDataset``. Defaults to None.
        classes (optional): Class names, forwarded unchanged to
            ``Custom3DDataset``. Defaults to None. How None is resolved
            (presumably falling back to ``CLASSES``) is decided by the base
            class; confirm there.
        test_mode (bool, optional): Forwarded to ``Custom3DDataset``.
            Defaults to False.
    """

    # Object categories, in label-index order.
    CLASSES = ('bed', 'table', 'sofa', 'chair', 'toilet', 'desk', 'dresser',
               'night_stand', 'bookshelf', 'bathtub')

    def __init__(self,
                 data_root,
                 ann_file,
                 pipeline=None,
                 classes=None,
                 test_mode=False):
        super().__init__(data_root, ann_file, pipeline, classes, test_mode)

    def _get_pts_filename(self, sample_idx):
        """Return the point-cloud file path for ``sample_idx``.

        Args:
            sample_idx (int): Sample index, zero-padded to 6 digits.

        Returns:
            str: ``<data_root>/lidar/<sample_idx:06d>.npy``.
        """
        pts_filename = osp.join(self.data_root, 'lidar',
                                f'{sample_idx:06d}.npy')
        return pts_filename

    def get_ann_info(self, index):
        """Get the 3D ground-truth annotations of the sample at ``index``.

        Args:
            index (int): Position of the sample in ``self.data_infos``.

        Returns:
            dict: Annotation dict with keys:

                - gt_bboxes_3d (np.ndarray): float32 boxes, one row per
                  object; shape ``(0, 7)`` when the sample has no objects.
                - gt_labels_3d (np.ndarray): int64 class labels, shape
                  ``(k, )``.
        """
        # Look the annotations up by index (not by sample id) so the
        # evaluation hook can reuse this API.
        info = self.data_infos[index]
        annos = info['annos']
        if annos['gt_num'] != 0:
            # NOTE(review): the original inline comment said "k, 6" but the
            # empty branch uses 7 columns; confirm the box layout against
            # the data converter.
            # Cast so empty and non-empty samples have matching dtypes.
            gt_bboxes_3d = np.asarray(
                annos['gt_boxes_upright_depth'], dtype=np.float32)
            gt_labels_3d = np.asarray(annos['class'], dtype=np.int64)
        else:
            gt_bboxes_3d = np.zeros((0, 7), dtype=np.float32)
            # np.zeros defaults to float64; labels are integer class ids.
            gt_labels_3d = np.zeros((0, ), dtype=np.int64)

        anns_results = dict(
            gt_bboxes_3d=gt_bboxes_3d, gt_labels_3d=gt_labels_3d)
        return anns_results