instance_seg_metric.py
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Sequence

from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger

from mmdet3d.evaluation import instance_seg_eval
from mmdet3d.registry import METRICS


@METRICS.register_module()
class InstanceSegMetric(BaseMetric):
    """3D instance segmentation evaluation metric.

    Args:
        collect_device (str, optional): Device name used for collecting
            results from different ranks during distributed training.
            Must be 'cpu' or 'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added to the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, ``self.default_prefix``
            will be used instead. Defaults to None.
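
    Example:
        A minimal evaluator setup in an MMEngine-style config (illustrative;
        the metric is resolved through the METRICS registry by its type
        name)::

            val_evaluator = dict(type='InstanceSegMetric')
            test_evaluator = val_evaluator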
    """

    def __init__(self,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None,
                 **kwargs):
        super().__init__(prefix=prefix, collect_device=collect_device)

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """Process one batch of data samples and predictions.

        The processed results should be stored in ``self.results``,
        which will be used to compute the metrics when all batches
        have been processed.

        Args:
            data_batch (dict): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from
                the model.
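
        Note:
            As consumed by ``compute_metrics`` below, each prediction in
            ``pred_pts_seg`` is expected to provide ``pts_instance_mask``,
            ``instance_labels`` and ``instance_scores``, and each
            ``eval_ann_info`` to provide ``pts_semantic_mask`` and
            ``pts_instance_mask``.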
        """
        for data_sample in data_samples:
            pred_3d = data_sample['pred_pts_seg']
            eval_ann_info = data_sample['eval_ann_info']
            # Move tensor-valued predictions to CPU so that the accumulated
            # results do not pin GPU memory across the whole validation set.
            cpu_pred_3d = dict()
            for k, v in pred_3d.items():
                if hasattr(v, 'to'):
                    cpu_pred_3d[k] = v.to('cpu')
                else:
                    cpu_pred_3d[k] = v
            self.results.append((eval_ann_info, cpu_pred_3d))

    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            Dict[str, float]: The computed metrics. The keys are the names of
            the metrics, and the values are corresponding results.
        """
        logger: MMLogger = MMLogger.get_current_instance()

        self.classes = self.dataset_meta['CLASSES']
        self.valid_class_ids = self.dataset_meta['seg_valid_class_ids']

        gt_semantic_masks = []
        gt_instance_masks = []
        pred_instance_masks = []
        pred_instance_labels = []
        pred_instance_scores = []

        for eval_ann, single_pred_results in results:
            gt_semantic_masks.append(eval_ann['pts_semantic_mask'])
            gt_instance_masks.append(eval_ann['pts_instance_mask'])
            pred_instance_masks.append(
                single_pred_results['pts_instance_mask'])
            pred_instance_labels.append(single_pred_results['instance_labels'])
            pred_instance_scores.append(single_pred_results['instance_scores'])

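        # ``instance_seg_eval`` matches the predicted instances against the
        # ground-truth masks for the valid classes and returns a dict of
        # AP-style scores; the exact keys are defined by
        # mmdet3d.evaluation.instance_seg_eval.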
        ret_dict = instance_seg_eval(
            gt_semantic_masks,
            gt_instance_masks,
            pred_instance_masks,
            pred_instance_labels,
            pred_instance_scores,
            valid_class_ids=self.valid_class_ids,
            class_labels=self.classes,
            logger=logger)

        return ret_dict
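

# --- Usage sketch (illustrative; not part of the original module) ---
# In an MMEngine config, the metric is resolved through the METRICS registry
# by its type name:
#
#   val_evaluator = dict(type='InstanceSegMetric')
#   test_evaluator = val_evaluator
#
# Driven by hand, the expected flow looks roughly as follows, assuming
# ``dataset_meta`` supplies the 'CLASSES' and 'seg_valid_class_ids' keys read
# in ``compute_metrics``. The class names/ids below are made up, and
# ``data_sample`` stands for a dict carrying 'pred_pts_seg' and
# 'eval_ann_info' as described in ``process``:
#
#   metric = InstanceSegMetric()
#   metric.dataset_meta = dict(
#       CLASSES=('cabinet', 'bed'),
#       seg_valid_class_ids=(3, 4))
#   metric.process(data_batch={}, data_samples=[data_sample])
#   ret = metric.evaluate(size=1)  # BaseMetric.evaluate() -> compute_metrics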