test_instance_seg_metric.py 3.21 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Copyright (c) OpenMMLab. All rights reserved.
import unittest

import numpy as np
import torch
from mmengine.data import BaseDataElement

from mmdet3d.core.data_structures import Det3DDataSample, PointData
from mmdet3d.metrics import InstanceSegMetric


class TestInstanceSegMetric(unittest.TestCase):
    """Smoke test for ``InstanceSegMetric`` on a synthetic point cloud.

    The ground truth and the predictions are built from the same layout:
    consecutive runs of ``_POINTS_PER_INSTANCE`` points, one run per entry
    of ``_GT_LABELS``, so the predicted instances coincide with the
    annotated ones.
    """

    # Total number of points in the synthetic scene.
    _NUM_POINTS = 3300
    # Number of consecutive points assigned to each instance.
    _POINTS_PER_INSTANCE = 300
    # Semantic label of each instance, indexed by instance id.  Only
    # 10 * 300 = 3000 points are covered, so the trailing 300 points keep
    # the fill value -1 in every mask.
    _GT_LABELS = (0, 0, 0, 0, 0, 0, 14, 14, 2, 1)

    def _instance_slices(self):
        """Yield ``(instance_id, label, slice)`` for every instance."""
        step = self._POINTS_PER_INSTANCE
        for instance_id, label in enumerate(self._GT_LABELS):
            begin = instance_id * step
            yield instance_id, label, slice(begin, begin + step)

    def _empty_mask(self):
        """Return a per-point integer mask filled with -1.

        The builtin ``int`` is used as dtype because the ``np.int`` alias
        was deprecated in NumPy 1.20 and removed in NumPy 1.24; ``int`` is
        what the alias referred to, so the resulting dtype is unchanged.
        """
        return np.full(self._NUM_POINTS, -1, dtype=int)

    def _demo_mm_inputs(self):
        """Create the ground-truth half of an evaluation batch.

        Returns:
            list[dict]: A single-element list whose item holds, under
            ``'data_sample'``, a ``Det3DDataSample`` converted with
            ``to_dict()``.  Its ``gt_pts_seg`` carries the
            ``pts_instance_mask`` and ``pts_semantic_mask`` tensors.
        """
        gt_instance_mask = self._empty_mask()
        gt_semantic_mask = self._empty_mask()
        for instance_id, label, span in self._instance_slices():
            gt_instance_mask[span] = instance_id
            gt_semantic_mask[span] = label

        data_sample = Det3DDataSample()
        data_sample.gt_pts_seg = PointData(
            pts_instance_mask=torch.tensor(gt_instance_mask),
            pts_semantic_mask=torch.tensor(gt_semantic_mask))

        return [dict(data_sample=data_sample.to_dict())]

    def _demo_mm_model_output(self):
        """Create the model predictions matching ``_demo_mm_inputs``.

        Every instance is predicted on exactly the points it occupies in
        the ground truth, with its ground-truth label and a score of 0.99.

        Returns:
            list: One entry per data sample; ``BaseDataElement`` instances
            are converted with ``to_dict()``.
        """
        pred_instance_mask = self._empty_mask()
        for instance_id, _, span in self._instance_slices():
            pred_instance_mask[span] = instance_id

        labels = list(self._GT_LABELS)
        scores = [.99] * len(labels)

        data_sample = Det3DDataSample()
        data_sample.pred_pts_seg = PointData(
            pts_instance_mask=torch.tensor(pred_instance_mask),
            instance_labels=torch.tensor(labels),
            instance_scores=torch.tensor(scores))
        batch_data_samples = [data_sample]

        # The metric is fed plain dicts, so data elements are unpacked.
        return [
            pred.to_dict() if isinstance(pred, BaseDataElement) else pred
            for pred in batch_data_samples
        ]

    def test_evaluate(self):
        """Run ``process`` + ``evaluate`` and check that a dict comes back."""
        data_batch = self._demo_mm_inputs()
        predictions = self._demo_mm_model_output()
        valid_class_ids = (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33,
                           34, 36, 39)
        class_labels = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door',
                        'window', 'bookshelf', 'picture', 'counter', 'desk',
                        'curtain', 'refrigerator', 'showercurtrain', 'toilet',
                        'sink', 'bathtub', 'garbagebin')
        dataset_meta = dict(
            VALID_CLASS_IDS=valid_class_ids, CLASSES=class_labels)
        instance_seg_metric = InstanceSegMetric()
        instance_seg_metric.dataset_meta = dataset_meta
        instance_seg_metric.process(data_batch, predictions)
        # NOTE(review): 6 is presumably the dataset size expected by
        # ``evaluate`` although only one sample was processed -- confirm.
        res = instance_seg_metric.evaluate(6)
        self.assertIsInstance(res, dict)