"git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "5a533e79b4635c2b62176d415a8ae2a38dab46f1"
Commit 92018ce1 authored by liyinhao

use self.get_classes()

parent cbb98068
@@ -91,14 +91,14 @@ def eval_det_cls(pred, gt, ovthresh=None):
     for a single class.

     Args:
-        pred (dict): map of {img_id: [(bbox, score)]} where bbox is numpy array
-        gt (dict): map of {img_id: [bbox]}
-        ovthresh (List[float]): a list, iou threshold
+        pred (dict): {img_id: [(bbox, score)]} where bbox is numpy array.
+        gt (dict): {img_id: [bbox]}.
+        ovthresh (List[float]): a list, iou threshold.

     Return:
-        ndarray: numpy array of length nd
-        ndarray: numpy array of length nd
-        float: scalar, average precision
+        ndarray: numpy array of length nd.
+        ndarray: numpy array of length nd.
+        float: scalar, average precision.
     """
     # construct gt objects
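For context, the docstring above implies the following call shape. This is a minimal hedged sketch: the 7-value box contents, the import path, and the unpacking of the three return values are assumptions read off the Args/Return sections, not taken from the repository.

import numpy as np

# Assumed import location; adjust to wherever this hunk's module lives.
from mmdet3d.core.evaluation.indoor_eval import eval_det_cls

# pred maps img_id -> list of (bbox, score); gt maps img_id -> list of bbox.
pred = {0: [(np.array([0., 0., 0., 1., 1., 1., 0.]), 0.9)]}
gt = {0: [np.array([0., 0., 0., 1., 1., 1., 0.])]}

# ovthresh is a list of IoU thresholds, per the Args section; the
# three-way unpacking follows the Return section.
rec, prec, ap = eval_det_cls(pred, gt, ovthresh=[0.25])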
@@ -295,13 +295,10 @@ def indoor_eval(gt_annos, dt_annos, metric, label2cat):
     ret_dict = {}
     for i, iou_thresh in enumerate(metric):
         for label in ap[i].keys():
-            ret_dict[f'{label2cat[label]}_AP_{int(iou_thresh * 100)}'] = ap[i][
-                label]
-        ret_dict[f'mAP_{int(iou_thresh * 100)}'] = sum(ap[i].values()) / len(
-            ap[i])
+            ret_dict[f'{label2cat[label]}_AP_{iou_thresh:.2f}'] = ap[i][label]
+        ret_dict[f'mAP_{iou_thresh:.2f}'] = sum(ap[i].values()) / len(ap[i])
         for label in rec[i].keys():
-            ret_dict[f'{label2cat[label]}_rec_{int(iou_thresh * 100)}'] = rec[
-                i][label]
-        ret_dict[f'mAR_{int(iou_thresh * 100)}'] = sum(rec[i].values()) / len(
-            rec[i])
+            ret_dict[f'{label2cat[label]}_rec_{iou_thresh:.2f}'] = rec[i][
+                label]
+        ret_dict[f'mAR_{iou_thresh:.2f}'] = sum(rec[i].values()) / len(rec[i])
     return ret_dict
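The user-visible effect of this hunk is the format of the returned metric keys: the IoU threshold is now embedded with two decimal places instead of being scaled to an integer percentage. A small illustration (the 'table' label is only an example):

iou_thresh = 0.25

# Old key style: threshold scaled to an integer percentage.
old_key = f'table_AP_{int(iou_thresh * 100)}'  # 'table_AP_25'

# New key style: raw threshold formatted to two decimals.
new_key = f'table_AP_{iou_thresh:.2f}'         # 'table_AP_0.25'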
@@ -20,7 +20,7 @@ class IndoorBaseDataset(torch_data.Dataset):
                  with_label=True):
         super().__init__()
         self.root_path = root_path
-        self.CLASSES = classes if classes else self.CLASSES
+        self.CLASSES = self.get_classes(classes)
         self.test_mode = test_mode
         self.label2cat = {i: cat_id for i, cat_id in enumerate(self.CLASSES)}
         mmcv.check_file_exist(ann_file)
@@ -77,6 +77,29 @@ class IndoorBaseDataset(torch_data.Dataset):
         example = self.pipeline(input_dict)
         return example

+    @classmethod
+    def get_classes(cls, classes=None):
+        """Get class names of current dataset.
+
+        Args:
+            classes (Sequence[str] | str | None): If classes is None, use
+                default CLASSES defined by builtin dataset. If classes is a
+                string, take it as a file name. The file contains the name of
+                classes where each line contains one class name. If classes is
+                a tuple or list, override the CLASSES defined by the dataset.
+        """
+        if classes is None:
+            return cls.CLASSES
+
+        if isinstance(classes, str):
+            # take it as a file path
+            class_names = mmcv.list_from_file(classes)
+        elif isinstance(classes, (tuple, list)):
+            class_names = classes
+        else:
+            raise ValueError(f'Unsupported type {type(classes)} of classes.')
+
+        return class_names
+
     def _generate_annotations(self, output):
         """Generate Annotations.
...
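The new classmethod dispatches on the type of its argument, as its docstring describes. A short usage sketch follows; the class list and file path are made-up illustrations, and mmcv.list_from_file simply reads one string per line from a text file.

# None: fall back to the dataset's builtin CLASSES attribute.
names = IndoorBaseDataset.get_classes()

# tuple/list: override the builtin CLASSES directly.
names = IndoorBaseDataset.get_classes(['table', 'window', 'counter'])

# str: treat it as a path to a file with one class name per line.
names = IndoorBaseDataset.get_classes('/path/to/my_classes.txt')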
@@ -109,10 +109,10 @@ def test_evaluate():
     results.append([pred_boxes])
     metric = [0.25, 0.5]
     ret_dict = scannet_dataset.evaluate(results, metric)
-    table_average_precision_25 = ret_dict['table_AP_25']
-    window_average_precision_25 = ret_dict['window_AP_25']
-    counter_average_precision_25 = ret_dict['counter_AP_25']
-    curtain_average_precision_25 = ret_dict['curtain_AP_25']
+    table_average_precision_25 = ret_dict['table_AP_0.25']
+    window_average_precision_25 = ret_dict['window_AP_0.25']
+    counter_average_precision_25 = ret_dict['counter_AP_0.25']
+    curtain_average_precision_25 = ret_dict['curtain_AP_0.25']
     assert abs(table_average_precision_25 - 0.3333) < 0.01
     assert abs(window_average_precision_25 - 1) < 0.01
     assert abs(counter_average_precision_25 - 1) < 0.01
...
@@ -85,9 +85,9 @@ def test_evaluate():
     results.append([pred_boxes])
     metric = [0.25, 0.5]
     ap_dict = sunrgbd_dataset.evaluate(results, metric)
-    bed_precision_25 = ap_dict['bed_AP_25']
-    dresser_precision_25 = ap_dict['dresser_AP_25']
-    night_stand_precision_25 = ap_dict['night_stand_AP_25']
+    bed_precision_25 = ap_dict['bed_AP_0.25']
+    dresser_precision_25 = ap_dict['dresser_AP_0.25']
+    night_stand_precision_25 = ap_dict['night_stand_AP_0.25']
     assert abs(bed_precision_25 - 1) < 0.01
     assert abs(dresser_precision_25 - 1) < 0.01
     assert abs(night_stand_precision_25 - 1) < 0.01