# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Sequence

from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger

from mmdet3d.evaluation import instance_seg_eval
from mmdet3d.registry import METRICS


@METRICS.register_module()
class InstanceSegMetric(BaseMetric):
    """3D instance segmentation evaluation metric.

    :meth:`process` stores one ``(eval_ann_info, cpu_pred_3d)`` tuple per
    sample in ``self.results``. :meth:`compute_metrics` then unpacks those
    tuples and scores them with ``instance_seg_eval``.

    Args:
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix will
            be used instead. Defaults to None.
    """

    def __init__(self,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None):
        super().__init__(prefix=prefix, collect_device=collect_device)

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """Process one batch of data samples and predictions.

        The processed results are stored in ``self.results``, which will be
        used to compute the metrics when all batches have been processed.
        Each appended entry is an ``(eval_ann_info, cpu_pred_3d)`` tuple.

        Args:
            data_batch (dict): A batch of data from the dataloader. Not used
                by this metric; accepted to match the ``process`` interface.
            data_samples (Sequence[dict]): A batch of outputs from the model.
                Each sample must provide the keys ``'pred_pts_seg'`` (whose
                value supports ``.items()``) and ``'eval_ann_info'``.
        """
        for data_sample in data_samples:
            pred_3d = data_sample['pred_pts_seg']
            eval_ann_info = data_sample['eval_ann_info']
            # Move every value that exposes ``.to`` (e.g. tensors) to the
            # CPU. Values without ``.to`` are stored unchanged.
            cpu_pred_3d = {
                k: v.to('cpu') if hasattr(v, 'to') else v
                for k, v in pred_3d.items()
            }
            self.results.append((eval_ann_info, cpu_pred_3d))

    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute the metrics from processed results.

        Reads ``'classes'`` and ``'seg_valid_class_ids'`` from
        ``self.dataset_meta`` and caches them as ``self.classes`` and
        ``self.valid_class_ids``.

        Args:
            results (list): The processed results of each batch, i.e. the
                ``(eval_ann_info, cpu_pred_3d)`` tuples appended by
                :meth:`process`.
                - ``eval_ann_info`` must contain ``'pts_semantic_mask'`` and
                  ``'pts_instance_mask'``.
                - ``cpu_pred_3d`` must contain ``'pts_instance_mask'``,
                  ``'instance_labels'`` and ``'instance_scores'``.

        Returns:
            Dict[str, float]: The computed metrics. The keys are the names of
            the metrics, and the values are corresponding results.
        """
        logger: MMLogger = MMLogger.get_current_instance()

        self.classes = self.dataset_meta['classes']
        self.valid_class_ids = self.dataset_meta['seg_valid_class_ids']

        gt_semantic_masks = []
        gt_instance_masks = []
        pred_instance_masks = []
        pred_instance_labels = []
        pred_instance_scores = []

        # Split the per-sample tuples into parallel lists: ground truth first,
        # then predictions, as ``instance_seg_eval`` expects.
        for eval_ann, single_pred_results in results:
            gt_semantic_masks.append(eval_ann['pts_semantic_mask'])
            gt_instance_masks.append(eval_ann['pts_instance_mask'])
            pred_instance_masks.append(
                single_pred_results['pts_instance_mask'])
            pred_instance_labels.append(single_pred_results['instance_labels'])
            pred_instance_scores.append(single_pred_results['instance_scores'])

        ret_dict = instance_seg_eval(
            gt_semantic_masks,
            gt_instance_masks,
            pred_instance_masks,
            pred_instance_labels,
            pred_instance_scores,
            valid_class_ids=self.valid_class_ids,
            class_labels=self.classes,
            logger=logger)

        return ret_dict