Commit df186989 authored by yinchimaoliang's avatar yinchimaoliang
Browse files

finish scannet_dataset

parent d71edf6c
import numpy as np
import torch
from mmdet3d.ops.iou3d import iou3d_cuda
from mmdet3d.core.bbox.iou_calculators.iou3d_calculator import bbox_overlaps_3d
def voc_ap(rec, prec, use_07_metric=False):
""" ap = voc_ap(rec, prec, [use_07_metric])
""" Voc AP
Compute VOC AP given precision and recall.
If use_07_metric is true, uses the
VOC 07 11 point method (default:False).
Args:
rec (ndarray): Recall.
prec (ndarray): Precision.
use_07_metric (bool): Whether to use 07 metric.
Default: False.
Returns:
ap (float): VOC AP.
"""
if use_07_metric:
# 11 point metric
......@@ -39,10 +47,15 @@ def voc_ap(rec, prec, use_07_metric=False):
def boxes3d_to_bevboxes_lidar_torch(boxes3d):
"""
:param boxes3d: (N, 7) [x, y, z, w, l, h, ry] in LiDAR coords
:return:
boxes_bev: (N, 5) [x1, y1, x2, y2, ry]
"""Boxes3d to Bevboxes Lidar.
Transform 3d boxes to bev boxes.
Args:
boxes3d (tensor): [x, y, z, w, l, h, ry] in LiDAR coords.
Returns:
boxes_bev (tensor): [x1, y1, x2, y2, ry].
"""
boxes_bev = boxes3d.new(torch.Size((boxes3d.shape[0], 5)))
......@@ -54,59 +67,26 @@ def boxes3d_to_bevboxes_lidar_torch(boxes3d):
return boxes_bev
def boxes_iou3d_gpu(boxes_a, boxes_b):
    """Compute pairwise 3D IoU between two sets of boxes on the GPU.

    The bird's-eye-view overlap is filled in by
    ``iou3d_cuda.boxes_overlap_bev_gpu`` and multiplied by the overlap
    along the height axis to obtain the intersection volume.

    Args:
        boxes_a (tensor): (N, 7) [x, y, z, w, l, h, ry] in LiDAR coords.
        boxes_b (tensor): (M, 7) boxes indexed the same way as ``boxes_a``:
            column 2 is used as the lower z bound and column 5 as the
            height, and columns 3, 4, 5 are multiplied for the volume.

    Returns:
        iou3d (tensor): (N, M) pairwise IoU.
    """
    boxes_a_bev = boxes3d_to_bevboxes_lidar_torch(boxes_a)
    boxes_b_bev = boxes3d_to_bevboxes_lidar_torch(boxes_b)

    # height overlap: column vectors for boxes_a and row vectors for
    # boxes_b so the min/max below broadcast to (N, M)
    boxes_a_height_max = (boxes_a[:, 2] + boxes_a[:, 5]).view(-1, 1)
    boxes_a_height_min = boxes_a[:, 2].view(-1, 1)
    boxes_b_height_max = (boxes_b[:, 2] + boxes_b[:, 5]).view(1, -1)
    boxes_b_height_min = boxes_b[:, 2].view(1, -1)

    # bev overlap, written into the pre-allocated buffer by the CUDA op
    overlaps_bev = boxes_a.new_zeros(
        torch.Size((boxes_a.shape[0], boxes_b.shape[0])))  # (N, M)
    iou3d_cuda.boxes_overlap_bev_gpu(boxes_a_bev.contiguous(),
                                     boxes_b_bev.contiguous(), overlaps_bev)

    max_of_min = torch.max(boxes_a_height_min, boxes_b_height_min)
    min_of_max = torch.min(boxes_a_height_max, boxes_b_height_max)
    # clamp at 0 so boxes that do not overlap in height contribute nothing
    overlaps_h = torch.clamp(min_of_max - max_of_min, min=0)

    # 3d iou
    overlaps_3d = overlaps_bev * overlaps_h

    vol_a = (boxes_a[:, 3] * boxes_a[:, 4] * boxes_a[:, 5]).view(-1, 1)
    vol_b = (boxes_b[:, 3] * boxes_b[:, 4] * boxes_b[:, 5]).view(1, -1)

    # the lower bound on the union avoids a division by zero
    iou3d = overlaps_3d / torch.clamp(vol_a + vol_b - overlaps_3d, min=1e-6)

    return iou3d


def get_iou_gpu(bb1, bb2):
    """Get IoU.

    Compute the IoU between two sets of bounding boxes on the GPU via
    ``bbox_overlaps_3d`` and hand the result back as a numpy array.

    Args:
        bb1 (ndarray): [x, y, z, w, l, h, ry] in LiDAR.
        bb2 (ndarray): Boxes in LiDAR coords; presumably the same layout
            as ``bb1`` - TODO confirm (an earlier docstring listed
            [x, y, z, h, w, l, ry]).

    Returns:
        ans_iou (ndarray): The IoU values, copied back to the CPU.
    """
    bb1 = torch.from_numpy(bb1).float().cuda()
    bb2 = torch.from_numpy(bb2).float().cuda()
    iou3d = bbox_overlaps_3d(bb1, bb2, mode='iou', coordinate='lidar')
    return iou3d.cpu().numpy()
def eval_det_cls(pred,
gt,
ovthresh=None,
use_07_metric=False,
get_iou_func=get_iou_gpu):
def eval_det_cls(pred, gt, ovthresh=None, use_07_metric=False):
""" Generic functions to compute precision/recall for object detection
for a single class.
Input:
......@@ -161,11 +141,9 @@ def eval_det_cls(pred,
ious.append(np.zeros(1))
confidence = np.array(confidence)
BB = np.array(BB) # (nd,4 or 8,3 or 6)
# sort by confidence
sorted_ind = np.argsort(-confidence)
BB = BB[sorted_ind, ...]
image_ids = [image_ids[x] for x in sorted_ind]
ious = [ious[x] for x in sorted_ind]
......@@ -214,27 +192,34 @@ def eval_det_cls(pred,
def eval_det_cls_wrapper(arguments):
    """Unpack one argument tuple and forward it to ``eval_det_cls``.

    Presumably exists so the per-class argument tuples can be mapped over
    with a single-argument callable - confirm against the caller.

    Args:
        arguments (tuple): (pred, gt, ovthresh, use_07_metric), in the
            positional order ``eval_det_cls`` expects.

    Returns:
        The value returned by ``eval_det_cls`` for these arguments.
    """
    pred, gt, ovthresh, use_07_metric = arguments
    return eval_det_cls(pred, gt, ovthresh, use_07_metric)
def eval_det_multiprocessing(pred_all,
gt_all,
ovthresh=None,
use_07_metric=False,
get_iou_func=get_iou_gpu):
""" Generic functions to compute precision/recall for object detection
use_07_metric=False):
""" Evaluate Detection Multiprocessing.
Generic functions to compute precision/recall for object detection
for multiple classes.
Input:
pred_all: map of {img_id: [(classname, bbox, score)]}
gt_all: map of {img_id: [(classname, bbox)]}
ovthresh: a list, iou threshold
use_07_metric: bool, if true use VOC07 11 point method
Output:
rec: {classname: rec}
prec: {classname: prec_all}
ap: {classname: scalar}
Args:
pred_all (dict): map of {img_id: [(classname, bbox, score)]}.
gt_all (dict): map of {img_id: [(classname, bbox)]}.
ovthresh (List[float]): iou threshold.
Default: None.
use_07_metric (bool): if true use VOC07 11 point method.
Default: False.
get_iou_func (func): The function to get iou.
Default: get_iou_gpu.
Return:
rec (dict): {classname: rec}.
prec (dict): {classname: prec_all}.
ap (dict): {classname: scalar}.
"""
pred = {} # map {classname: pred}
gt = {} # map {classname: gt}
......@@ -258,8 +243,8 @@ def eval_det_multiprocessing(pred_all,
gt[classname][img_id].append(bbox)
ret_values = []
args = [(pred[classname], gt[classname], ovthresh, use_07_metric,
get_iou_func) for classname in gt.keys() if classname in pred]
args = [(pred[classname], gt[classname], ovthresh, use_07_metric)
for classname in gt.keys() if classname in pred]
rec = [{} for i in ovthresh]
prec = [{} for i in ovthresh]
ap = [{} for i in ovthresh]
......@@ -281,29 +266,33 @@ def eval_det_multiprocessing(pred_all,
class APCalculator(object):
''' Calculating Average Precision '''
"""AP Calculator.
Calculating Average Precision.
def __init__(self, ap_iou_thresh=None, class2type_map=None):
"""
Args:
ap_iou_thresh: a list, which contains float between 0 and 1.0
ap_iou_thresh (List[float]): a list,
which contains float between 0 and 1.0
IoU threshold to judge whether a prediction is positive.
class2type_map: [optional] dict {class_int:class_name}
class2type_map (dict): {class_int:class_name}.
"""
def __init__(self, ap_iou_thresh=None, class2type_map=None):
self.ap_iou_thresh = ap_iou_thresh
self.class2type_map = class2type_map
self.reset()
def step(self, det_infos, gt_infos):
""" Accumulate one batch of prediction and groundtruth.
""" Step.
Accumulate one batch of prediction and groundtruth.
Args:
batch_pred_map_cls: a list of lists
[[(pred_cls, pred_box_params, score),...],...]
batch_gt_map_cls: a list of lists
[[(gt_cls, gt_box_params),...],...]
should have the same length with batch_pred_map_cls
(batch_size)
batch_pred_map_cls (List[List]): a list of lists
[[(pred_cls, pred_box_params, score),...],...].
batch_gt_map_cls (List[List]): a list of lists
[[(gt_cls, gt_box_params),...],...].
"""
# cache pred infos
for batch_pred_map_cls in det_infos:
......@@ -322,13 +311,9 @@ class APCalculator(object):
self.scan_cnt += 1
def compute_metrics(self):
""" Use accumulated predictions and groundtruths to compute Average Precision.
"""
recs, precs, aps = eval_det_multiprocessing(
self.pred_map_cls,
self.gt_map_cls,
ovthresh=self.ap_iou_thresh,
get_iou_func=get_iou_gpu)
self.pred_map_cls, self.gt_map_cls, ovthresh=self.ap_iou_thresh)
ret = []
for i, iou_thresh in enumerate(self.ap_iou_thresh):
ret_dict = {}
......@@ -347,7 +332,7 @@ class APCalculator(object):
ret_dict['%s Recall %d' %
(clsname, iou_thresh * 100)] = rec[key][-1]
rec_list.append(rec[key][-1])
except KeyError:
except TypeError:
ret_dict['%s Recall %d' % (clsname, iou_thresh * 100)] = 0
rec_list.append(0)
ret_dict['AR%d' % (iou_thresh * 100)] = np.mean(rec_list)
......@@ -361,10 +346,15 @@ class APCalculator(object):
def boxes3d_depth_to_lidar(boxes3d, mid_to_bottom=True):
""" Flip X-right,Y-forward,Z-up to X-forward,Y-left,Z-up
:param boxes3d_depth: (N, 7) [x, y, z, w, l, h, r] in depth coords
:return:
boxes3d_lidar: (N, 7) [x, y, z, l, h, w, r] in LiDAR coords
""" Boxes3d Depth to Lidar.
Flip X-right,Y-forward,Z-up to X-forward,Y-left,Z-up.
Args:
boxes3d (ndarray): (N, 7) [x, y, z, w, l, h, r] in depth coords.
Return:
boxes3d_lidar (ndarray): (N, 7) [x, y, z, l, h, w, r] in LiDAR coords.
"""
boxes3d_lidar = boxes3d.copy()
boxes3d_lidar[..., [0, 1, 2, 3, 4, 5]] = boxes3d_lidar[...,
......@@ -376,6 +366,21 @@ def boxes3d_depth_to_lidar(boxes3d, mid_to_bottom=True):
def scannet_eval(gt_annos, dt_annos, metric, class2type):
"""Scannet Evaluation.
Evaluate the result of the detection.
Args:
gt_annos (List): GT annotations.
dt_annos (List): Detection annotations.
metric (dict): AP IoU thresholds.
class2type (dict): {class: type}.
Return:
result_str (str): Result string.
metrics_dict (dict): Result.
"""
for gt_anno in gt_annos:
if gt_anno['gt_num'] != 0:
# convert to lidar coor for evaluation
......@@ -384,7 +389,7 @@ def scannet_eval(gt_annos, dt_annos, metric, class2type):
gt_anno['gt_boxes_upright_depth'] = np.pad(bbox_lidar_bottom,
((0, 0), (0, 1)),
'constant')
ap_iou_thresholds = metric.AP_IOU_THRESHHOLDS
ap_iou_thresholds = metric['AP_IOU_THRESHHOLDS']
ap_calculator = APCalculator(ap_iou_thresholds, class2type)
ap_calculator.step(dt_annos, gt_annos)
result_str = str()
......
......@@ -187,7 +187,7 @@ class ScannetDataset(torch_data.Dataset):
pred_boxes = output[i]
box3d_depth = pred_boxes['box3d_lidar']
if box3d_depth is not None:
label_preds = pred_boxes.get['label_preds']
label_preds = pred_boxes['label_preds']
scores = pred_boxes['scores'].detach().cpu().numpy()
label_preds = label_preds.detach().cpu().numpy()
num_proposal = box3d_depth.shape[0]
......@@ -216,7 +216,8 @@ class ScannetDataset(torch_data.Dataset):
gt_annos = [
copy.deepcopy(info['annos']) for info in self.scannet_infos
]
ap_result_str, ap_dict = scannet_eval(gt_annos, results)
ap_result_str, ap_dict = scannet_eval(gt_annos, results, metric,
self.class2type)
return ap_dict
def __len__(self):
......
import numpy as np
import torch
from mmdet3d.datasets.scannet_dataset import ScannetDataset
......@@ -68,3 +69,49 @@ def test_getitem():
assert np.all(gt_labels.numpy() == expected_gt_labels)
assert np.all(pts_semantic_mask == expected_pts_semantic_mask)
assert np.all(pts_instance_mask == expected_pts_instance_mask)
def test_evaluate():
    """Check the AP@0.25 values reported by ``ScannetDataset.evaluate``.

    Five hand-written boxes with labels and scores are passed in as the
    prediction. The label and score tensors are created on the GPU, so
    this test needs CUDA.
    """
    data_root = './tests/data/scannet'
    info_path = './tests/data/scannet/scannet_infos.pkl'
    scannet_dataset = ScannetDataset(data_root, info_path)

    # One row of seven values per predicted box.
    box_rows = [
        [
            3.52074146e+00, -1.48129511e+00, 1.57035351e+00,
            2.31956959e-01, 1.74445975e+00, 5.72351933e-01, 0
        ],
        [
            -3.48033905e+00, -2.90395617e+00, 1.19105673e+00,
            1.70723915e-01, 6.60776615e-01, 6.71535969e-01, 0
        ],
        [
            2.19867110e+00, -1.14655101e+00, 9.25755501e-03,
            2.53463078e+00, 5.41841269e-01, 1.21447623e+00, 0
        ],
        [
            2.50163722, -2.91681337, 0.82875049,
            1.84280431, 0.61697435, 0.28697443, 0
        ],
        [
            -0.01335114, 3.3114481, -0.00895238,
            3.85815716, 0.44081616, 2.16034412, 0
        ],
    ]
    pred_boxes = {
        'box3d_lidar': np.array(box_rows),
        'label_preds': torch.Tensor([6, 6, 4, 9, 11]).cuda(),
        'scores': torch.Tensor([0.5, 1.0, 1.0, 1.0, 1.0]).cuda(),
    }
    results = [[pred_boxes]]
    metric = {'AP_IOU_THRESHHOLDS': [0.25, 0.5]}

    ap_dict = scannet_dataset.evaluate(results, metric)

    # Look up every class first, then compare each against its target.
    observed = [
        ap_dict['table Average Precision 25'],
        ap_dict['window Average Precision 25'],
        ap_dict['counter Average Precision 25'],
        ap_dict['curtain Average Precision 25'],
    ]
    expected = [0.3333, 1, 1, 0.5]
    for got, want in zip(observed, expected):
        assert abs(got - want) < 0.01
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment