sunrgbd_dataset.py 2.56 KB
Newer Older
liyinhao's avatar
liyinhao committed
1
2
import os.path as osp

liyinhao's avatar
liyinhao committed
3
4
import numpy as np

liyinhao's avatar
liyinhao committed
5
from mmdet3d.core import show_result
wuyuefeng's avatar
wuyuefeng committed
6
from mmdet3d.core.bbox import DepthInstance3DBoxes
liyinhao's avatar
liyinhao committed
7
from mmdet.datasets import DATASETS
zhangwenwei's avatar
zhangwenwei committed
8
from .custom_3d import Custom3DDataset
liyinhao's avatar
liyinhao committed
9
10
11


@DATASETS.register_module()
class SUNRGBDDataset(Custom3DDataset):
    """SUN RGB-D dataset built on top of :class:`Custom3DDataset`.

    The class only customises annotation loading (``get_ann_info``) and
    result visualisation (``show``); everything else is inherited.

    Args:
        data_root (str): Root directory; joined with each sample's
            ``pts_path`` when point files are read in ``show``.
        ann_file (str): Annotation file, forwarded to the parent class.
        pipeline (list[dict], optional): Forwarded to the parent class.
            Defaults to None.
        classes (tuple[str], optional): Forwarded to the parent class.
            Defaults to None.
        modality (dict, optional): Forwarded to the parent class.
            Defaults to None.
        box_type_3d (str, optional): Forwarded to the parent class.
            Defaults to 'Depth'.
        filter_empty_gt (bool, optional): Forwarded to the parent class.
            Defaults to True.
        test_mode (bool, optional): Forwarded to the parent class.
            Defaults to False.
    """

    CLASSES = ('bed', 'table', 'sofa', 'chair', 'toilet', 'desk', 'dresser',
               'night_stand', 'bookshelf', 'bathtub')

    def __init__(self,
                 data_root,
                 ann_file,
                 pipeline=None,
                 classes=None,
                 modality=None,
                 box_type_3d='Depth',
                 filter_empty_gt=True,
                 test_mode=False):
        super().__init__(
            data_root=data_root,
            ann_file=ann_file,
            pipeline=pipeline,
            classes=classes,
            modality=modality,
            box_type_3d=box_type_3d,
            filter_empty_gt=filter_empty_gt,
            test_mode=test_mode)

    def get_ann_info(self, index):
        """Build the ground-truth annotation dict for one sample.

        Args:
            index (int): Position of the sample in ``self.data_infos``.

        Returns:
            dict: With the keys

                - gt_bboxes_3d: boxes wrapped in ``DepthInstance3DBoxes``
                  and converted to ``self.box_mode_3d``.
                - gt_labels_3d (np.ndarray): int64 class labels.
        """
        # Use index to get the annos, thus the evalhook could also use this api
        annos = self.data_infos[index]['annos']
        if annos['gt_num'] != 0:
            gt_bboxes_3d = annos['gt_boxes_upright_depth'].astype(np.float32)
            # ``np.long`` was removed in NumPy 1.24; ``np.int64`` is the
            # explicit, platform-independent replacement.
            gt_labels_3d = annos['class'].astype(np.int64)
        else:
            # No ground truth: use empty arrays so downstream code still
            # receives correctly typed inputs.
            gt_bboxes_3d = np.zeros((0, 7), dtype=np.float32)
            gt_labels_3d = np.zeros((0, ), dtype=np.int64)

        # to target box structure
        gt_bboxes_3d = DepthInstance3DBoxes(
            gt_bboxes_3d, origin=(0.5, 0.5, 0.5)).convert_to(self.box_mode_3d)

        anns_results = dict(
            gt_bboxes_3d=gt_bboxes_3d, gt_labels_3d=gt_labels_3d)
        return anns_results

    def show(self, results, out_dir):
        """Dump points, ground-truth boxes and predicted boxes per sample.

        Args:
            results (list[dict]): One entry per sample, in the same order as
                ``self.data_infos``; each must provide ``'boxes_3d'`` whose
                ``.tensor`` can be converted with ``.numpy()``.
            out_dir (str): Output directory handed to ``show_result``.

        Raises:
            AssertionError: If ``out_dir`` is None.
        """
        assert out_dir is not None, 'Expect out_dir, got none.'
        for i, result in enumerate(results):
            data_info = self.data_infos[i]
            pts_path = data_info['pts_path']
            # Base name of the point file up to its first dot.
            file_name = osp.split(pts_path)[-1].split('.')[0]
            # The point file is read as flat float32 with 6 values per point.
            points = np.fromfile(
                osp.join(self.data_root, pts_path),
                dtype=np.float32).reshape(-1, 6)
            # NOTE(review): columns 3+ are presumably colours stored in
            # [0, 1] and rescaled for display - confirm against the converter.
            points[:, 3:] *= 255
            if data_info['annos']['gt_num'] > 0:
                gt_bboxes = data_info['annos']['gt_boxes_upright_depth']
            else:
                gt_bboxes = np.zeros((0, 7))
            # ``.numpy()`` shares memory with the tensor, so copy before the
            # in-place shift below; otherwise the caller's predictions would
            # be modified (and shifted again on every further call).
            pred_bboxes = result['boxes_3d'].tensor.numpy().copy()
            # NOTE(review): adds half of column 5 to column 2, presumably to
            # move z from the box bottom to its centre - confirm.
            pred_bboxes[..., 2] += pred_bboxes[..., 5] / 2
            show_result(points, gt_bboxes, pred_bboxes, out_dir, file_name)