Commit 92018ce1 authored by liyinhao's avatar liyinhao
Browse files

use self.get_classes()

parent cbb98068
......@@ -91,14 +91,14 @@ def eval_det_cls(pred, gt, ovthresh=None):
for a single class.
Args:
pred (dict): map of {img_id: [(bbox, score)]} where bbox is numpy array
gt (dict): map of {img_id: [bbox]}
ovthresh (List[float]): a list, iou threshold
pred (dict): {img_id: [(bbox, score)]} where bbox is numpy array.
gt (dict): {img_id: [bbox]}.
ovthresh (List[float]): a list, iou threshold.
Return:
ndarray: numpy array of length nd
ndarray: numpy array of length nd
float: scalar, average precision
ndarray: numpy array of length nd.
ndarray: numpy array of length nd.
float: scalar, average precision.
"""
# construct gt objects
......@@ -295,13 +295,10 @@ def indoor_eval(gt_annos, dt_annos, metric, label2cat):
ret_dict = {}
for i, iou_thresh in enumerate(metric):
for label in ap[i].keys():
ret_dict[f'{label2cat[label]}_AP_{int(iou_thresh * 100)}'] = ap[i][
label]
ret_dict[f'mAP_{int(iou_thresh * 100)}'] = sum(ap[i].values()) / len(
ap[i])
ret_dict[f'{label2cat[label]}_AP_{iou_thresh:.2f}'] = ap[i][label]
ret_dict[f'mAP_{iou_thresh:.2f}'] = sum(ap[i].values()) / len(ap[i])
for label in rec[i].keys():
ret_dict[f'{label2cat[label]}_rec_{int(iou_thresh * 100)}'] = rec[
i][label]
ret_dict[f'mAR_{int(iou_thresh * 100)}'] = sum(rec[i].values()) / len(
rec[i])
ret_dict[f'{label2cat[label]}_rec_{iou_thresh:.2f}'] = rec[i][
label]
ret_dict[f'mAR_{iou_thresh:.2f}'] = sum(rec[i].values()) / len(rec[i])
return ret_dict
......@@ -20,7 +20,7 @@ class IndoorBaseDataset(torch_data.Dataset):
with_label=True):
super().__init__()
self.root_path = root_path
self.CLASSES = classes if classes else self.CLASSES
self.CLASSES = self.get_classes(classes)
self.test_mode = test_mode
self.label2cat = {i: cat_id for i, cat_id in enumerate(self.CLASSES)}
mmcv.check_file_exist(ann_file)
......@@ -77,6 +77,29 @@ class IndoorBaseDataset(torch_data.Dataset):
example = self.pipeline(input_dict)
return example
@classmethod
def get_classes(cls, classes=None):
"""Get class names of current dataset
Args:
classes (Sequence[str] | str | None): If classes is None, use
default CLASSES defined by builtin dataset. If classes is a
string, take it as a file name. The file contains the name of
classes where each line contains one class name. If classes is
a tuple or list, override the CLASSES defined by the dataset.
"""
if classes is None:
return cls.CLASSES
if isinstance(classes, str):
# take it as a file path
class_names = mmcv.list_from_file(classes)
elif isinstance(classes, (tuple, list)):
class_names = classes
else:
raise ValueError(f'Unsupported type {type(classes)} of classes.')
return class_names
def _generate_annotations(self, output):
"""Generate Annotations.
......
......@@ -109,10 +109,10 @@ def test_evaluate():
results.append([pred_boxes])
metric = [0.25, 0.5]
ret_dict = scannet_dataset.evaluate(results, metric)
table_average_precision_25 = ret_dict['table_AP_25']
window_average_precision_25 = ret_dict['window_AP_25']
counter_average_precision_25 = ret_dict['counter_AP_25']
curtain_average_precision_25 = ret_dict['curtain_AP_25']
table_average_precision_25 = ret_dict['table_AP_0.25']
window_average_precision_25 = ret_dict['window_AP_0.25']
counter_average_precision_25 = ret_dict['counter_AP_0.25']
curtain_average_precision_25 = ret_dict['curtain_AP_0.25']
assert abs(table_average_precision_25 - 0.3333) < 0.01
assert abs(window_average_precision_25 - 1) < 0.01
assert abs(counter_average_precision_25 - 1) < 0.01
......
......@@ -85,9 +85,9 @@ def test_evaluate():
results.append([pred_boxes])
metric = [0.25, 0.5]
ap_dict = sunrgbd_dataset.evaluate(results, metric)
bed_precision_25 = ap_dict['bed_AP_25']
dresser_precision_25 = ap_dict['dresser_AP_25']
night_stand_precision_25 = ap_dict['night_stand_AP_25']
bed_precision_25 = ap_dict['bed_AP_0.25']
dresser_precision_25 = ap_dict['dresser_AP_0.25']
night_stand_precision_25 = ap_dict['night_stand_AP_0.25']
assert abs(bed_precision_25 - 1) < 0.01
assert abs(dresser_precision_25 - 1) < 0.01
assert abs(night_stand_precision_25 - 1) < 0.01
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment