# Copyright (c) OpenMMLab. All rights reserved.
import unittest

import torch
from mmengine.structures import BaseDataElement
from mmdet3d.evaluation.metrics import SegMetric
from mmdet3d.structures import Det3DDataSample, PointData


class TestSegMetric(unittest.TestCase):
    """Unit tests for ``SegMetric`` on a single toy point cloud."""

    def _demo_mm_model_output(self):
        """Build a one-sample batch of model outputs as plain dicts.

        The single ``Det3DDataSample`` carries a predicted and a ground-truth
        point-wise semantic mask (23 points each), both wrapped in
        ``PointData``.

        Returns:
            list: One entry per data sample, converted with
            ``BaseDataElement.to_dict()``; this is the form passed to
            ``SegMetric.process`` in ``test_evaluate``.
        """
        pred_pts_semantic_mask = torch.LongTensor([
            0, 0, 1, 0, 0, 2, 1, 3, 1, 2, 1, 0, 2, 2, 2, 2, 1, 3, 0, 3, 3, 3, 3
        ])
        data_sample = Det3DDataSample()
        data_sample.pred_pts_seg = PointData(
            pts_semantic_mask=pred_pts_semantic_mask)

        # Label 4 equals the ignore_index configured in test_evaluate
        # (len(classes) == 4), so these points are presumably excluded from
        # scoring by SegMetric.
        gt_pts_semantic_mask = torch.LongTensor([
            0, 0, 0, 4, 0, 0, 1, 1, 1, 4, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4
        ])
        data_sample.gt_pts_seg = PointData(
            pts_semantic_mask=gt_pts_semantic_mask)

        batch_data_samples = [data_sample]
        # The metric is fed dicts here, so convert any data elements.
        return [
            elem.to_dict() if isinstance(elem, BaseDataElement) else elem
            for elem in batch_data_samples
        ]

    def test_evaluate(self):
        """``SegMetric.evaluate`` returns a dict after processing one batch."""
        data_batch = {}
        predictions = self._demo_mm_model_output()

        dataset_meta = dict(classes=('car', 'bicyle', 'motorcycle', 'truck'))
        # The first label past the last class id (here 4) is the ignore index.
        seg_metric = SegMetric(ignore_index=len(dataset_meta['classes']))
        seg_metric.dataset_meta = dataset_meta
        seg_metric.process(data_batch, predictions)
        # Argument is presumably the dataset size; exactly one sample was
        # processed above.
        res = seg_metric.evaluate(1)
        self.assertIsInstance(res, dict)