Commit 1071dab5 authored by Jingwei Zhang's avatar Jingwei Zhang Committed by ZwwWayne
Browse files

[Enhance] Show instance statistics before and after through pipeline (#1863)

* add instance statistics before and after through pipeline

* add docstring

* support showing cat-wise instance statistics

* show all statistics of the dataset

* small fix

* polish code

* show table

* small fix

* rename some varibles
parent 3960f9a7
......@@ -5,7 +5,10 @@ from typing import Callable, List, Optional, Union
import mmengine
import numpy as np
import torch
from mmengine.dataset import BaseDataset
from mmengine.logging import print_log
from terminaltables import AsciiTable
from mmdet3d.datasets import DATASETS
from mmdet3d.structures import get_box_type
......@@ -58,6 +61,9 @@ class Det3DDataset(BaseDataset):
which can be used in Evaluator. Defaults to True.
file_client_args (dict, optional): Configuration of file client.
Defaults to dict(backend='disk').
show_ins_var (bool, optional): For debug purpose. Whether to show
variation of the number of instances before and after through
pipeline. Defaults to False.
"""
def __init__(self,
......@@ -73,6 +79,7 @@ class Det3DDataset(BaseDataset):
test_mode: bool = False,
load_eval_anns=True,
file_client_args: dict = dict(backend='disk'),
show_ins_var: bool = False,
**kwargs) -> None:
# init file client
self.file_client = mmengine.FileClient(**file_client_args)
......@@ -105,6 +112,8 @@ class Det3DDataset(BaseDataset):
for label_idx, name in enumerate(metainfo['CLASSES']):
ori_label = self.METAINFO['CLASSES'].index(name)
self.label_mapping[ori_label] = label_idx
self.num_ins_per_cat = {name: 0 for name in metainfo['CLASSES']}
else:
self.label_mapping = {
i: i
......@@ -112,6 +121,11 @@ class Det3DDataset(BaseDataset):
}
self.label_mapping[-1] = -1
self.num_ins_per_cat = {
name: 0
for name in self.METAINFO['CLASSES']
}
super().__init__(
ann_file=ann_file,
metainfo=metainfo,
......@@ -125,7 +139,22 @@ class Det3DDataset(BaseDataset):
self.metainfo['box_type_3d'] = box_type_3d
self.metainfo['label_mapping'] = self.label_mapping
def _remove_dontcare(self, ann_info: dict) -> dict:
# used for showing variation of the number of instances before and
# after through the pipeline
self.show_ins_var = show_ins_var
# show statistics of this dataset
print_log('-' * 30, 'current')
print_log(f'The length of the dataset: {len(self)}', 'current')
content_show = [['category', 'number']]
for cat_name, num in self.num_ins_per_cat.items():
content_show.append([cat_name, num])
table = AsciiTable(content_show)
print_log(
f'The number of instances per category in the dataset:\n{table.table}', # noqa: E501
'current')
def _remove_dontcare(self, ann_info):
"""Remove annotations that do not need to be cared.
-1 indicate dontcare in MMDet3d.
......@@ -223,6 +252,11 @@ class Det3DDataset(BaseDataset):
ann_info[mapped_ann_name] = temp_anns
ann_info['instances'] = info['instances']
for label in ann_info['gt_labels_3d']:
cat_name = self.metainfo['CLASSES'][label]
self.num_ins_per_cat[cat_name] += 1
return ann_info
def parse_data_info(self, info: dict) -> dict:
......@@ -291,6 +325,31 @@ class Det3DDataset(BaseDataset):
return info
def _show_ins_var(self, old_labels: np.ndarray, new_labels: torch.Tensor):
"""Show variation of the number of instances before and after through
the pipeline.
Args:
old_labels (np.ndarray): The labels before through the pipeline.
new_labels (torch.Tensor): The labels after through the pipeline.
"""
ori_num_per_cat = dict()
for label in old_labels:
cat_name = self.metainfo['CLASSES'][label]
ori_num_per_cat[cat_name] = ori_num_per_cat.get(cat_name, 0) + 1
new_num_per_cat = dict()
for label in new_labels:
cat_name = self.metainfo['CLASSES'][label]
new_num_per_cat[cat_name] = new_num_per_cat.get(cat_name, 0) + 1
content_show = [['category', 'new number', 'ori number']]
for cat_name, num in ori_num_per_cat.items():
new_num = new_num_per_cat.get(cat_name, 0)
content_show.append([cat_name, new_num, num])
table = AsciiTable(content_show)
print_log(
'The number of instances per category after and before '
f'through pipeline:\n{table.table}', 'current')
def prepare_data(self, index: int) -> Optional[dict]:
"""Data preparation for both training and testing stage.
......@@ -302,10 +361,10 @@ class Det3DDataset(BaseDataset):
Returns:
dict | None: Data dict of the corresponding index.
"""
input_dict = self.get_data_info(index)
ori_input_dict = self.get_data_info(index)
# deepcopy here to avoid inplace modification in pipeline.
input_dict = copy.deepcopy(input_dict)
input_dict = copy.deepcopy(ori_input_dict)
# box_type_3d (str): 3D box type.
input_dict['box_type_3d'] = self.box_type_3d
......@@ -318,12 +377,19 @@ class Det3DDataset(BaseDataset):
return None
example = self.pipeline(input_dict)
if not self.test_mode and self.filter_empty_gt:
# after pipeline drop the example with empty annotations
# return None to random another in `__getitem__`
if example is None or len(
example['data_samples'].gt_instances_3d.labels_3d) == 0:
return None
if self.show_ins_var:
self._show_ins_var(
ori_input_dict['ann_info']['gt_labels_3d'],
example['data_samples'].gt_instances_3d.labels_3d)
return example
def get_cat_ids(self, idx: int) -> List[int]:
......
......@@ -133,12 +133,12 @@ class DataBaseSampler(object):
from mmengine.logging import MMLogger
logger: MMLogger = MMLogger.get_current_instance()
for k, v in db_infos.items():
logger.info(f'load {len(v)} {k} database infos')
logger.info(f'load {len(v)} {k} database infos in DataBaseSampler')
for prep_func, val in prepare.items():
db_infos = getattr(self, prep_func)(db_infos, val)
logger.info('After filter database:')
for k, v in db_infos.items():
logger.info(f'load {len(v)} {k} database infos')
logger.info(f'load {len(v)} {k} database infos in DataBaseSampler')
self.db_infos = db_infos
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment