# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Sequence

from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger

from mmdet3d.evaluation import instance_seg_eval
from mmdet3d.registry import METRICS


@METRICS.register_module()
class InstanceSegMetric(BaseMetric):
    """Evaluation metric for 3D instance segmentation.

    Per-sample annotations and predictions are accumulated by
    :meth:`process`; :meth:`compute_metrics` then hands everything to
    ``instance_seg_eval`` in one call.

    Args:
        collect_device (str, optional): Device name used for collecting
            results from different ranks during distributed training.
            Must be 'cpu' or 'gpu'. Defaults to 'cpu'.
        prefix (str, optional): Prefix added to the metric names to
            disambiguate homonymous metrics of different evaluators. When
            it is not given, ``self.default_prefix`` is used instead.
            Defaults to None.
        **kwargs: Accepted for compatibility; these are not forwarded to
            the base class.
    """

    def __init__(self,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None,
                 **kwargs):
        # Only ``prefix`` and ``collect_device`` reach the base class.
        super().__init__(prefix=prefix, collect_device=collect_device)

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """Record the annotations and predictions of one batch.

        One ``(eval_ann_info, prediction)`` tuple per sample is appended to
        ``self.results``, which is consumed once all batches have been
        processed.

        Args:
            data_batch (dict): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the
                model. Each sample is read through its ``'pred_pts_seg'``
                and ``'eval_ann_info'`` entries.
        """
        for sample in data_samples:
            prediction = sample['pred_pts_seg']
            annotation = sample['eval_ann_info']
            # Anything exposing ``.to`` is moved to the CPU; every other
            # value is stored untouched.
            prediction_on_cpu = {
                key: value.to('cpu') if hasattr(value, 'to') else value
                for key, value in prediction.items()
            }
            self.results.append((annotation, prediction_on_cpu))

    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Evaluate all collected samples with ``instance_seg_eval``.

        The class names and valid class ids are read from
        ``self.dataset_meta`` and kept on the instance as ``self.classes``
        and ``self.valid_class_ids``.

        Args:
            results (list): The ``(eval_ann_info, prediction)`` tuples
                gathered by :meth:`process`.

        Returns:
            Dict[str, float]: The computed metrics, exactly as returned by
            ``instance_seg_eval``. The keys are the names of the metrics
            and the values are the corresponding results.
        """
        logger: MMLogger = MMLogger.get_current_instance()

        dataset_meta = self.dataset_meta
        self.classes = dataset_meta['classes']
        self.valid_class_ids = dataset_meta['seg_valid_class_ids']

        # Regroup the per-sample tuples into the per-field lists that the
        # evaluation function takes as positional arguments.
        gt_semantic_masks = []
        gt_instance_masks = []
        pred_instance_masks = []
        pred_instance_labels = []
        pred_instance_scores = []
        for annotation, prediction in results:
            gt_semantic_masks.append(annotation['pts_semantic_mask'])
            gt_instance_masks.append(annotation['pts_instance_mask'])
            pred_instance_masks.append(prediction['pts_instance_mask'])
            pred_instance_labels.append(prediction['instance_labels'])
            pred_instance_scores.append(prediction['instance_scores'])

        return instance_seg_eval(
            gt_semantic_masks,
            gt_instance_masks,
            pred_instance_masks,
            pred_instance_labels,
            pred_instance_scores,
            valid_class_ids=self.valid_class_ids,
            class_labels=self.classes,
            logger=logger)