test_seg_metric.py 1.74 KB
Newer Older
1
2
3
# Copyright (c) OpenMMLab. All rights reserved.
import unittest

4
import numpy as np
5
import torch
6
from mmengine.structures import BaseDataElement
7

zhangshilong's avatar
zhangshilong committed
8
9
from mmdet3d.evaluation.metrics import SegMetric
from mmdet3d.structures import Det3DDataSample, PointData
10
11
12
13
14
15


class TestSegMetric(unittest.TestCase):
    """Unit tests for :class:`mmdet3d.evaluation.metrics.SegMetric`."""

    def _demo_mm_model_output(self):
        """Build a one-sample prediction list as passed to ``process``.

        The single ``Det3DDataSample`` holds two masks:

        * a predicted point-wise semantic mask in ``pred_pts_seg``;
        * the matching ground-truth mask in ``eval_ann_info``.

        Ground-truth entries equal to 255 correspond to the
        ``ignore_index`` configured in :meth:`test_evaluate`.

        Returns:
            list: One entry per data sample. ``BaseDataElement`` instances
            are converted to plain dicts via ``to_dict()``.
        """
        # Semantic labels are integer class ids, so build an integer tensor
        # instead of the float32 tensor that ``torch.Tensor`` would produce.
        pred_pts_semantic_mask = torch.tensor([
            0, 0, 1, 0, 0, 2, 1, 3, 1, 2, 1, 0, 2, 2, 2, 2, 1, 3, 0, 3, 3, 3, 3
        ], dtype=torch.long)
        data_sample = Det3DDataSample()
        data_sample.pred_pts_seg = PointData(
            pts_semantic_mask=pred_pts_semantic_mask)

        # 255 marks points that should be ignored (matches ``ignore_index``
        # in the dataset meta set up by ``test_evaluate``).
        gt_pts_semantic_mask = np.array([
            0, 0, 0, 255, 0, 0, 1, 1, 1, 255, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3,
            3, 255
        ])
        data_sample.eval_ann_info = dict(
            pts_semantic_mask=gt_pts_semantic_mask)

        batch_data_samples = [data_sample]
        # The metric is fed plain dicts, so convert data elements up front.
        return [
            sample.to_dict() if isinstance(sample, BaseDataElement) else sample
            for sample in batch_data_samples
        ]

    def test_evaluate(self):
        """Run ``SegMetric`` end to end on the demo sample.

        Checks that ``evaluate`` returns a dict of metrics.
        """
        data_batch = {}
        predictions = self._demo_mm_model_output()

        label2cat = {
            0: 'car',
            1: 'bicycle',
            2: 'motorcycle',
            3: 'truck',
        }
        dataset_meta = dict(label2cat=label2cat, ignore_index=255)
        seg_metric = SegMetric()
        seg_metric.dataset_meta = dataset_meta
        seg_metric.process(data_batch, predictions)
        # Only one sample was processed above, hence a size of 1.
        res = seg_metric.evaluate(1)
        self.assertIsInstance(res, dict)