sunrgbd_dataset.py 1.47 KB
Newer Older
liyinhao's avatar
liyinhao committed
1
2
3
import numpy as np

from mmdet.datasets import DATASETS
zhangwenwei's avatar
zhangwenwei committed
4
from .custom_3d import Custom3DDataset
liyinhao's avatar
liyinhao committed
5
6
7


@DATASETS.register_module()
class SUNRGBDDataset(Custom3DDataset):
    """SUN RGB-D 3D detection dataset.

    A thin subclass of ``Custom3DDataset`` that fixes the default class
    vocabulary to the ten categories in ``CLASSES`` and defaults
    ``box_type_3d`` to ``'Depth'``. All constructor arguments are forwarded
    unchanged to ``Custom3DDataset``; see that class for their semantics.
    """

    CLASSES = ('bed', 'table', 'sofa', 'chair', 'toilet', 'desk', 'dresser',
               'night_stand', 'bookshelf', 'bathtub')

    def __init__(self,
                 data_root,
                 ann_file,
                 pipeline=None,
                 classes=None,
                 modality=None,
                 box_type_3d='Depth',
                 filter_empty_gt=True,
                 test_mode=False):
        """Forward every argument, by keyword, to ``Custom3DDataset``.

        Args:
            data_root: Passed through to ``Custom3DDataset``.
            ann_file: Passed through to ``Custom3DDataset``.
            pipeline: Passed through. Defaults to None.
            classes: Passed through. Defaults to None.
            modality: Passed through. Defaults to None.
            box_type_3d (str): Passed through. Defaults to ``'Depth'``.
            filter_empty_gt (bool): Passed through. Defaults to True.
            test_mode (bool): Passed through. Defaults to False.
        """
        super().__init__(
            data_root=data_root,
            ann_file=ann_file,
            pipeline=pipeline,
            classes=classes,
            modality=modality,
            box_type_3d=box_type_3d,
            filter_empty_gt=filter_empty_gt,
            test_mode=test_mode)

    def get_ann_info(self, index):
        """Return the 3D ground-truth annotations for sample ``index``.

        Annotations are looked up by index so that evaluation hooks can
        call this API as well.

        Args:
            index (int): Index into ``self.data_infos``.

        Returns:
            dict: ``gt_bboxes_3d`` -- float32 array built from
            ``annos['gt_boxes_upright_depth']`` (shape ``(0, 7)`` when
            ``annos['gt_num']`` is 0); ``gt_labels_3d`` -- int64 array built
            from ``annos['class']`` (shape ``(0,)`` when ``gt_num`` is 0).
        """
        annos = self.data_infos[index]['annos']
        if annos['gt_num'] != 0:
            # NOTE(review): boxes presumably have shape (k, 7) to match the
            # empty case below (an older comment said "k, 6") -- confirm
            # against the script that generates the info files.
            gt_bboxes_3d = annos['gt_boxes_upright_depth'].astype(np.float32)
            # np.long was a deprecated alias of builtin int and was removed
            # in NumPy 1.24; int64 is explicit and matches torch.long.
            gt_labels_3d = annos['class'].astype(np.int64)
        else:
            gt_bboxes_3d = np.zeros((0, 7), dtype=np.float32)
            gt_labels_3d = np.zeros((0, ), dtype=np.int64)

        return dict(gt_bboxes_3d=gt_bboxes_3d, gt_labels_3d=gt_labels_3d)