Commit 955b4419 authored by VVsssssk, committed by ChaimZhu

[Refactor]Add kitti metric

parent 7e87d837
@@ -282,12 +282,13 @@ def compute_statistics_jit(overlaps,


 def get_split_parts(num, num_part):
-    same_part = num // num_part
-    remain_num = num % num_part
-    if remain_num == 0:
+    if num % num_part == 0:
+        same_part = num // num_part
         return [same_part] * num_part
     else:
-        return [same_part] * num_part + [remain_num]
+        same_part = num // (num_part - 1)
+        remain_num = num % (num_part - 1)
+        return [same_part] * (num_part - 1) + [remain_num]


 @numba.jit(nopython=True)
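For intuition, a minimal sketch of the new splitting behaviour (not part of the commit; the function body is copied from the hunk above):

def get_split_parts(num, num_part):
    if num % num_part == 0:
        same_part = num // num_part
        return [same_part] * num_part
    else:
        same_part = num // (num_part - 1)
        remain_num = num % (num_part - 1)
        return [same_part] * (num_part - 1) + [remain_num]

# Old behaviour: get_split_parts(11, 5) == [2, 2, 2, 2, 2, 1]  (6 chunks).
# New behaviour: the remainder case splits over num_part - 1 equal chunks
# plus one tail chunk, so the result never exceeds num_part entries.
assert get_split_parts(11, 5) == [2, 2, 2, 2, 3]
assert get_split_parts(10, 5) == [2, 2, 2, 2, 2]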
@@ -340,57 +341,57 @@ def fused_compute_statistics(overlaps,
         dc_num += dc_nums[i]


-def calculate_iou_partly(gt_annos, dt_annos, metric, num_parts=50):
+def calculate_iou_partly(dt_annos, gt_annos, metric, num_parts=50):
     """Fast iou algorithm. this function can be used independently to do result
     analysis. Must be used in CAMERA coordinate system.

     Args:
-        gt_annos (dict): Must from get_label_annos() in kitti_common.py.
         dt_annos (dict): Must from get_label_annos() in kitti_common.py.
+        gt_annos (dict): Must from get_label_annos() in kitti_common.py.
         metric (int): Eval type. 0: bbox, 1: bev, 2: 3d.
         num_parts (int): A parameter for fast calculate algorithm.
     """
-    assert len(gt_annos) == len(dt_annos)
-    total_dt_num = np.stack([len(a['name']) for a in dt_annos], 0)
-    total_gt_num = np.stack([len(a['name']) for a in gt_annos], 0)
-    num_examples = len(gt_annos)
+    assert len(dt_annos) == len(gt_annos)
+    total_dt_num = np.stack([len(a['name']) for a in gt_annos], 0)
+    total_gt_num = np.stack([len(a['name']) for a in dt_annos], 0)
+    num_examples = len(dt_annos)
     split_parts = get_split_parts(num_examples, num_parts)
     parted_overlaps = []
     example_idx = 0
     for num_part in split_parts:
-        gt_annos_part = gt_annos[example_idx:example_idx + num_part]
         dt_annos_part = dt_annos[example_idx:example_idx + num_part]
+        gt_annos_part = gt_annos[example_idx:example_idx + num_part]
         if metric == 0:
-            gt_boxes = np.concatenate([a['bbox'] for a in gt_annos_part], 0)
-            dt_boxes = np.concatenate([a['bbox'] for a in dt_annos_part], 0)
+            gt_boxes = np.concatenate([a['bbox'] for a in dt_annos_part], 0)
+            dt_boxes = np.concatenate([a['bbox'] for a in gt_annos_part], 0)
             overlap_part = image_box_overlap(gt_boxes, dt_boxes)
         elif metric == 1:
             loc = np.concatenate(
-                [a['location'][:, [0, 2]] for a in gt_annos_part], 0)
+                [a['location'][:, [0, 2]] for a in dt_annos_part], 0)
             dims = np.concatenate(
-                [a['dimensions'][:, [0, 2]] for a in gt_annos_part], 0)
-            rots = np.concatenate([a['rotation_y'] for a in gt_annos_part], 0)
+                [a['dimensions'][:, [0, 2]] for a in dt_annos_part], 0)
+            rots = np.concatenate([a['rotation_y'] for a in dt_annos_part], 0)
             gt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]],
                                       axis=1)
             loc = np.concatenate(
-                [a['location'][:, [0, 2]] for a in dt_annos_part], 0)
+                [a['location'][:, [0, 2]] for a in gt_annos_part], 0)
             dims = np.concatenate(
-                [a['dimensions'][:, [0, 2]] for a in dt_annos_part], 0)
-            rots = np.concatenate([a['rotation_y'] for a in dt_annos_part], 0)
+                [a['dimensions'][:, [0, 2]] for a in gt_annos_part], 0)
+            rots = np.concatenate([a['rotation_y'] for a in gt_annos_part], 0)
             dt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]],
                                       axis=1)
             overlap_part = bev_box_overlap(gt_boxes,
                                            dt_boxes).astype(np.float64)
         elif metric == 2:
-            loc = np.concatenate([a['location'] for a in gt_annos_part], 0)
-            dims = np.concatenate([a['dimensions'] for a in gt_annos_part], 0)
-            rots = np.concatenate([a['rotation_y'] for a in gt_annos_part], 0)
-            gt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]],
-                                      axis=1)
             loc = np.concatenate([a['location'] for a in dt_annos_part], 0)
             dims = np.concatenate([a['dimensions'] for a in dt_annos_part], 0)
             rots = np.concatenate([a['rotation_y'] for a in dt_annos_part], 0)
+            gt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]],
+                                      axis=1)
+            loc = np.concatenate([a['location'] for a in gt_annos_part], 0)
+            dims = np.concatenate([a['dimensions'] for a in gt_annos_part], 0)
+            rots = np.concatenate([a['rotation_y'] for a in gt_annos_part], 0)
             dt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]],
                                       axis=1)
             overlap_part = d3_box_overlap(gt_boxes,
@@ -402,8 +403,8 @@ def calculate_iou_partly(gt_annos, dt_annos, metric, num_parts=50):
     overlaps = []
     example_idx = 0
     for j, num_part in enumerate(split_parts):
-        gt_annos_part = gt_annos[example_idx:example_idx + num_part]
         dt_annos_part = dt_annos[example_idx:example_idx + num_part]
+        gt_annos_part = gt_annos[example_idx:example_idx + num_part]
         gt_num_idx, dt_num_idx = 0, 0
         for i in range(num_part):
             gt_box_num = total_gt_num[example_idx + i]
@@ -480,6 +481,7 @@ def eval_class(gt_annos,
     rets = calculate_iou_partly(dt_annos, gt_annos, metric, num_parts)
     overlaps, parted_overlaps, total_dt_num, total_gt_num = rets
     N_SAMPLE_PTS = 41
     num_minoverlap = len(min_overlaps)
     num_class = len(current_classes)
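For reference, a hedged sketch of the call pattern after the refactor. It mirrors the unchanged call site in eval_class above; the toy annotations and num_parts=1 are illustrative, and the sketch assumes calculate_iou_partly and its helpers are importable from the module being edited:

import numpy as np

# Toy single-frame annotations; only the keys needed for metric=0 (2D bbox
# IoU) are filled in. Real annos come from get_label_annos() as the
# docstring above states.
dt_annos = [dict(name=np.array(['Car']), bbox=np.array([[0., 0., 10., 10.]]))]
gt_annos = [dict(name=np.array(['Car']), bbox=np.array([[2., 2., 10., 10.]]))]

rets = calculate_iou_partly(dt_annos, gt_annos, metric=0, num_parts=1)
overlaps, parted_overlaps, total_dt_num, total_gt_num = rets
assert len(overlaps) == len(dt_annos)  # one overlap matrix per sample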
# Copyright (c) OpenMMLab. All rights reserved.
from .kitti_metric import KittiMetric  # noqa: F401,F403

__all__ = ['KittiMetric']
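A hypothetical config-level usage sketch of the newly exported metric (not part of this commit; the keyword name ann_file and the registry key 'KittiMetric' are assumptions, since the tests below only show the positional form KittiMetric(<infos_pkl>, metric=['mAP'])):

# Hypothetical MMEngine-style evaluator config; field names are assumptions.
val_evaluator = dict(
    type='KittiMetric',
    ann_file='data/kitti/kitti_infos_val.pkl',  # path is illustrative
    metric=['mAP'])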
import numpy as np
import pytest
import torch
from mmengine.data import InstanceData

from mmdet3d.core import Det3DDataSample
from mmdet3d.core.bbox import LiDARInstance3DBoxes
from mmdet3d.metrics import KittiMetric

data_root = 'tests/data/kitti'


def _init_evaluate_input():
    data_batch = [dict(data_sample=dict(sample_idx=0))]
    predictions = Det3DDataSample()
    pred_instances_3d = InstanceData()
    pred_instances_3d.bboxes_3d = LiDARInstance3DBoxes(
        torch.tensor(
            [[8.7314, -1.8559, -1.5997, 0.4800, 1.2000, 1.8900, 0.0100]]))
    pred_instances_3d.scores_3d = torch.Tensor([0.9])
    pred_instances_3d.labels_3d = torch.Tensor([0])
    predictions.pred_instances_3d = pred_instances_3d
    predictions = predictions.to_dict()
    return data_batch, [predictions]


def _init_multi_modal_evaluate_input():
    data_batch = [dict(data_sample=dict(sample_idx=0))]
    predictions = Det3DDataSample()
    pred_instances_3d = InstanceData()
    pred_instances = InstanceData()
    pred_instances.bboxes = torch.tensor([[712.4, 143, 810.7, 307.92]])
    pred_instances.scores = torch.Tensor([0.9])
    pred_instances.labels = torch.Tensor([0])
    pred_instances_3d.bboxes_3d = LiDARInstance3DBoxes(
        torch.tensor(
            [[8.7314, -1.8559, -1.5997, 0.4800, 1.2000, 1.8900, 0.0100]]))
    pred_instances_3d.scores_3d = torch.Tensor([0.9])
    pred_instances_3d.labels_3d = torch.Tensor([0])
    predictions.pred_instances_3d = pred_instances_3d
    predictions.pred_instances = pred_instances
    predictions = predictions.to_dict()
    return data_batch, [predictions]
def test_multi_modal_kitti_metric():
    if not torch.cuda.is_available():
        pytest.skip('test requires GPU and torch+cuda')
    kittimetric = KittiMetric(
        data_root + '/kitti_infos_train.pkl', metric=['mAP'])
    kittimetric.dataset_meta = dict(CLASSES=['Car', 'Pedestrian', 'Cyclist'])
    data_batch, predictions = _init_multi_modal_evaluate_input()
    kittimetric.process(data_batch, predictions)
    ap_dict = kittimetric.compute_metrics(kittimetric.results)
    assert np.isclose(ap_dict['pred_instances_3d/KITTI/Overall_3D_AP11_easy'],
                      3.0303030303030307)
    assert np.isclose(ap_dict['pred_instances_3d/KITTI/Overall_BEV_AP11_easy'],
                      3.0303030303030307)
    assert np.isclose(ap_dict['pred_instances_3d/KITTI/Overall_2D_AP11_easy'],
                      3.0303030303030307)
    assert np.isclose(ap_dict['pred_instances/KITTI/Overall_2D_AP11_easy'],
                      3.0303030303030307)
    assert np.isclose(ap_dict['pred_instances/KITTI/Overall_2D_AP11_moderate'],
                      3.0303030303030307)
    assert np.isclose(ap_dict['pred_instances/KITTI/Overall_2D_AP11_hard'],
                      3.0303030303030307)


def test_kitti_metric_mAP():
    if not torch.cuda.is_available():
        pytest.skip('test requires GPU and torch+cuda')
    kittimetric = KittiMetric(
        data_root + '/kitti_infos_train.pkl', metric=['mAP'])
    kittimetric.dataset_meta = dict(CLASSES=['Car', 'Pedestrian', 'Cyclist'])
    data_batch, predictions = _init_evaluate_input()
    kittimetric.process(data_batch, predictions)
    ap_dict = kittimetric.compute_metrics(kittimetric.results)
    assert np.isclose(ap_dict['pred_instances_3d/KITTI/Overall_3D_AP11_easy'],
                      3.0303030303030307)
    assert np.isclose(
        ap_dict['pred_instances_3d/KITTI/Overall_3D_AP11_moderate'],
        3.0303030303030307)
    assert np.isclose(ap_dict['pred_instances_3d/KITTI/Overall_3D_AP11_hard'],
                      3.0303030303030307)
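The asserted keys above also document the layout of the returned dict; a small hypothetical helper showing how a caller might read it (the key pattern is taken from the tests, the helper itself is not part of the commit):

def get_overall_ap(ap_dict, group='pred_instances_3d', kind='3D',
                   difficulty='easy'):
    # Keys observed above: <group>/KITTI/Overall_<3D|BEV|2D>_AP11_<difficulty>
    return ap_dict[f'{group}/KITTI/Overall_{kind}_AP11_{difficulty}']

# e.g. get_overall_ap(ap_dict, kind='BEV') reads
# ap_dict['pred_instances_3d/KITTI/Overall_BEV_AP11_easy'].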