Commit 0241806b authored by Jingwei Zhang's avatar Jingwei Zhang Committed by ZwwWayne
Browse files

[Fix] fix instance statistics when only detecting a single class (#2003)

parent bf02d499
...@@ -89,7 +89,6 @@ train_pipeline = [ ...@@ -89,7 +89,6 @@ train_pipeline = [
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range), dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectNameFilter', classes=class_names), dict(type='ObjectNameFilter', classes=class_names),
dict(type='PointShuffle'), dict(type='PointShuffle'),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict( dict(
type='Pack3DDetInputs', type='Pack3DDetInputs',
keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
......
...@@ -255,8 +255,9 @@ class Det3DDataset(BaseDataset): ...@@ -255,8 +255,9 @@ class Det3DDataset(BaseDataset):
ann_info['instances'] = info['instances'] ann_info['instances'] = info['instances']
for label in ann_info['gt_labels_3d']: for label in ann_info['gt_labels_3d']:
cat_name = self.metainfo['classes'][label] if label != -1:
self.num_ins_per_cat[cat_name] += 1 cat_name = self.metainfo['classes'][label]
self.num_ins_per_cat[cat_name] += 1
return ann_info return ann_info
...@@ -336,12 +337,16 @@ class Det3DDataset(BaseDataset): ...@@ -336,12 +337,16 @@ class Det3DDataset(BaseDataset):
""" """
ori_num_per_cat = dict() ori_num_per_cat = dict()
for label in old_labels: for label in old_labels:
cat_name = self.metainfo['classes'][label] if label != -1:
ori_num_per_cat[cat_name] = ori_num_per_cat.get(cat_name, 0) + 1 cat_name = self.metainfo['classes'][label]
ori_num_per_cat[cat_name] = ori_num_per_cat.get(cat_name,
0) + 1
new_num_per_cat = dict() new_num_per_cat = dict()
for label in new_labels: for label in new_labels:
cat_name = self.metainfo['classes'][label] if label != -1:
new_num_per_cat[cat_name] = new_num_per_cat.get(cat_name, 0) + 1 cat_name = self.metainfo['classes'][label]
new_num_per_cat[cat_name] = new_num_per_cat.get(cat_name,
0) + 1
content_show = [['category', 'new number', 'ori number']] content_show = [['category', 'new number', 'ori number']]
for cat_name, num in ori_num_per_cat.items(): for cat_name, num in ori_num_per_cat.items():
new_num = new_num_per_cat.get(cat_name, 0) new_num = new_num_per_cat.get(cat_name, 0)
...@@ -387,9 +392,16 @@ class Det3DDataset(BaseDataset): ...@@ -387,9 +392,16 @@ class Det3DDataset(BaseDataset):
return None return None
if self.show_ins_var: if self.show_ins_var:
self._show_ins_var( if 'ann_info' in ori_input_dict:
ori_input_dict['ann_info']['gt_labels_3d'], self._show_ins_var(
example['data_samples'].gt_instances_3d.labels_3d) ori_input_dict['ann_info']['gt_labels_3d'],
example['data_samples'].gt_instances_3d.labels_3d)
else:
print_log(
"'ann_info' is not in the input dict. It's probably that "
'the data is not in training mode',
'current',
level=30)
return example return example
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment