Commit 1071dab5 authored by Jingwei Zhang's avatar Jingwei Zhang Committed by ZwwWayne
Browse files

[Enhance] Show instance statistics before and after through pipeline (#1863)

* add instance statistics before and after through pipeline

* add docstring

* support showing cat-wise instance statistics

* show all statistics of the dataset

* small fix

* polish code

* show table

* small fix

* rename some varibles
parent 3960f9a7
...@@ -5,7 +5,10 @@ from typing import Callable, List, Optional, Union ...@@ -5,7 +5,10 @@ from typing import Callable, List, Optional, Union
import mmengine import mmengine
import numpy as np import numpy as np
import torch
from mmengine.dataset import BaseDataset from mmengine.dataset import BaseDataset
from mmengine.logging import print_log
from terminaltables import AsciiTable
from mmdet3d.datasets import DATASETS from mmdet3d.datasets import DATASETS
from mmdet3d.structures import get_box_type from mmdet3d.structures import get_box_type
...@@ -58,6 +61,9 @@ class Det3DDataset(BaseDataset): ...@@ -58,6 +61,9 @@ class Det3DDataset(BaseDataset):
which can be used in Evaluator. Defaults to True. which can be used in Evaluator. Defaults to True.
file_client_args (dict, optional): Configuration of file client. file_client_args (dict, optional): Configuration of file client.
Defaults to dict(backend='disk'). Defaults to dict(backend='disk').
show_ins_var (bool, optional): For debug purpose. Whether to show
variation of the number of instances before and after through
pipeline. Defaults to False.
""" """
def __init__(self, def __init__(self,
...@@ -73,6 +79,7 @@ class Det3DDataset(BaseDataset): ...@@ -73,6 +79,7 @@ class Det3DDataset(BaseDataset):
test_mode: bool = False, test_mode: bool = False,
load_eval_anns=True, load_eval_anns=True,
file_client_args: dict = dict(backend='disk'), file_client_args: dict = dict(backend='disk'),
show_ins_var: bool = False,
**kwargs) -> None: **kwargs) -> None:
# init file client # init file client
self.file_client = mmengine.FileClient(**file_client_args) self.file_client = mmengine.FileClient(**file_client_args)
...@@ -105,6 +112,8 @@ class Det3DDataset(BaseDataset): ...@@ -105,6 +112,8 @@ class Det3DDataset(BaseDataset):
for label_idx, name in enumerate(metainfo['CLASSES']): for label_idx, name in enumerate(metainfo['CLASSES']):
ori_label = self.METAINFO['CLASSES'].index(name) ori_label = self.METAINFO['CLASSES'].index(name)
self.label_mapping[ori_label] = label_idx self.label_mapping[ori_label] = label_idx
self.num_ins_per_cat = {name: 0 for name in metainfo['CLASSES']}
else: else:
self.label_mapping = { self.label_mapping = {
i: i i: i
...@@ -112,6 +121,11 @@ class Det3DDataset(BaseDataset): ...@@ -112,6 +121,11 @@ class Det3DDataset(BaseDataset):
} }
self.label_mapping[-1] = -1 self.label_mapping[-1] = -1
self.num_ins_per_cat = {
name: 0
for name in self.METAINFO['CLASSES']
}
super().__init__( super().__init__(
ann_file=ann_file, ann_file=ann_file,
metainfo=metainfo, metainfo=metainfo,
...@@ -125,7 +139,22 @@ class Det3DDataset(BaseDataset): ...@@ -125,7 +139,22 @@ class Det3DDataset(BaseDataset):
self.metainfo['box_type_3d'] = box_type_3d self.metainfo['box_type_3d'] = box_type_3d
self.metainfo['label_mapping'] = self.label_mapping self.metainfo['label_mapping'] = self.label_mapping
def _remove_dontcare(self, ann_info: dict) -> dict: # used for showing variation of the number of instances before and
# after through the pipeline
self.show_ins_var = show_ins_var
# show statistics of this dataset
print_log('-' * 30, 'current')
print_log(f'The length of the dataset: {len(self)}', 'current')
content_show = [['category', 'number']]
for cat_name, num in self.num_ins_per_cat.items():
content_show.append([cat_name, num])
table = AsciiTable(content_show)
print_log(
f'The number of instances per category in the dataset:\n{table.table}', # noqa: E501
'current')
def _remove_dontcare(self, ann_info):
"""Remove annotations that do not need to be cared. """Remove annotations that do not need to be cared.
-1 indicate dontcare in MMDet3d. -1 indicate dontcare in MMDet3d.
...@@ -223,6 +252,11 @@ class Det3DDataset(BaseDataset): ...@@ -223,6 +252,11 @@ class Det3DDataset(BaseDataset):
ann_info[mapped_ann_name] = temp_anns ann_info[mapped_ann_name] = temp_anns
ann_info['instances'] = info['instances'] ann_info['instances'] = info['instances']
for label in ann_info['gt_labels_3d']:
cat_name = self.metainfo['CLASSES'][label]
self.num_ins_per_cat[cat_name] += 1
return ann_info return ann_info
def parse_data_info(self, info: dict) -> dict: def parse_data_info(self, info: dict) -> dict:
...@@ -291,6 +325,31 @@ class Det3DDataset(BaseDataset): ...@@ -291,6 +325,31 @@ class Det3DDataset(BaseDataset):
return info return info
def _show_ins_var(self, old_labels: np.ndarray, new_labels: torch.Tensor):
"""Show variation of the number of instances before and after through
the pipeline.
Args:
old_labels (np.ndarray): The labels before through the pipeline.
new_labels (torch.Tensor): The labels after through the pipeline.
"""
ori_num_per_cat = dict()
for label in old_labels:
cat_name = self.metainfo['CLASSES'][label]
ori_num_per_cat[cat_name] = ori_num_per_cat.get(cat_name, 0) + 1
new_num_per_cat = dict()
for label in new_labels:
cat_name = self.metainfo['CLASSES'][label]
new_num_per_cat[cat_name] = new_num_per_cat.get(cat_name, 0) + 1
content_show = [['category', 'new number', 'ori number']]
for cat_name, num in ori_num_per_cat.items():
new_num = new_num_per_cat.get(cat_name, 0)
content_show.append([cat_name, new_num, num])
table = AsciiTable(content_show)
print_log(
'The number of instances per category after and before '
f'through pipeline:\n{table.table}', 'current')
def prepare_data(self, index: int) -> Optional[dict]: def prepare_data(self, index: int) -> Optional[dict]:
"""Data preparation for both training and testing stage. """Data preparation for both training and testing stage.
...@@ -302,10 +361,10 @@ class Det3DDataset(BaseDataset): ...@@ -302,10 +361,10 @@ class Det3DDataset(BaseDataset):
Returns: Returns:
dict | None: Data dict of the corresponding index. dict | None: Data dict of the corresponding index.
""" """
input_dict = self.get_data_info(index) ori_input_dict = self.get_data_info(index)
# deepcopy here to avoid inplace modification in pipeline. # deepcopy here to avoid inplace modification in pipeline.
input_dict = copy.deepcopy(input_dict) input_dict = copy.deepcopy(ori_input_dict)
# box_type_3d (str): 3D box type. # box_type_3d (str): 3D box type.
input_dict['box_type_3d'] = self.box_type_3d input_dict['box_type_3d'] = self.box_type_3d
...@@ -318,12 +377,19 @@ class Det3DDataset(BaseDataset): ...@@ -318,12 +377,19 @@ class Det3DDataset(BaseDataset):
return None return None
example = self.pipeline(input_dict) example = self.pipeline(input_dict)
if not self.test_mode and self.filter_empty_gt: if not self.test_mode and self.filter_empty_gt:
# after pipeline drop the example with empty annotations # after pipeline drop the example with empty annotations
# return None to random another in `__getitem__` # return None to random another in `__getitem__`
if example is None or len( if example is None or len(
example['data_samples'].gt_instances_3d.labels_3d) == 0: example['data_samples'].gt_instances_3d.labels_3d) == 0:
return None return None
if self.show_ins_var:
self._show_ins_var(
ori_input_dict['ann_info']['gt_labels_3d'],
example['data_samples'].gt_instances_3d.labels_3d)
return example return example
def get_cat_ids(self, idx: int) -> List[int]: def get_cat_ids(self, idx: int) -> List[int]:
......
...@@ -133,12 +133,12 @@ class DataBaseSampler(object): ...@@ -133,12 +133,12 @@ class DataBaseSampler(object):
from mmengine.logging import MMLogger from mmengine.logging import MMLogger
logger: MMLogger = MMLogger.get_current_instance() logger: MMLogger = MMLogger.get_current_instance()
for k, v in db_infos.items(): for k, v in db_infos.items():
logger.info(f'load {len(v)} {k} database infos') logger.info(f'load {len(v)} {k} database infos in DataBaseSampler')
for prep_func, val in prepare.items(): for prep_func, val in prepare.items():
db_infos = getattr(self, prep_func)(db_infos, val) db_infos = getattr(self, prep_func)(db_infos, val)
logger.info('After filter database:') logger.info('After filter database:')
for k, v in db_infos.items(): for k, v in db_infos.items():
logger.info(f'load {len(v)} {k} database infos') logger.info(f'load {len(v)} {k} database infos in DataBaseSampler')
self.db_infos = db_infos self.db_infos = db_infos
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment