Commit 955b4419 authored by VVsssssk's avatar VVsssssk Committed by ChaimZhu
Browse files

[Refactor]Add kitti metric

parent 7e87d837
......@@ -282,12 +282,13 @@ def compute_statistics_jit(overlaps,
def get_split_parts(num, num_part):
same_part = num // num_part
remain_num = num % num_part
if remain_num == 0:
if num % num_part == 0:
same_part = num // num_part
return [same_part] * num_part
else:
return [same_part] * num_part + [remain_num]
same_part = num // (num_part - 1)
remain_num = num % (num_part - 1)
return [same_part] * (num_part - 1) + [remain_num]
@numba.jit(nopython=True)
......@@ -340,57 +341,57 @@ def fused_compute_statistics(overlaps,
dc_num += dc_nums[i]
def calculate_iou_partly(gt_annos, dt_annos, metric, num_parts=50):
def calculate_iou_partly(dt_annos, gt_annos, metric, num_parts=50):
"""Fast iou algorithm. this function can be used independently to do result
analysis. Must be used in CAMERA coordinate system.
Args:
gt_annos (dict): Must from get_label_annos() in kitti_common.py.
dt_annos (dict): Must from get_label_annos() in kitti_common.py.
gt_annos (dict): Must from get_label_annos() in kitti_common.py.
metric (int): Eval type. 0: bbox, 1: bev, 2: 3d.
num_parts (int): A parameter for fast calculate algorithm.
"""
assert len(gt_annos) == len(dt_annos)
total_dt_num = np.stack([len(a['name']) for a in dt_annos], 0)
total_gt_num = np.stack([len(a['name']) for a in gt_annos], 0)
num_examples = len(gt_annos)
assert len(dt_annos) == len(gt_annos)
total_dt_num = np.stack([len(a['name']) for a in gt_annos], 0)
total_gt_num = np.stack([len(a['name']) for a in dt_annos], 0)
num_examples = len(dt_annos)
split_parts = get_split_parts(num_examples, num_parts)
parted_overlaps = []
example_idx = 0
for num_part in split_parts:
gt_annos_part = gt_annos[example_idx:example_idx + num_part]
dt_annos_part = dt_annos[example_idx:example_idx + num_part]
gt_annos_part = gt_annos[example_idx:example_idx + num_part]
if metric == 0:
gt_boxes = np.concatenate([a['bbox'] for a in gt_annos_part], 0)
dt_boxes = np.concatenate([a['bbox'] for a in dt_annos_part], 0)
gt_boxes = np.concatenate([a['bbox'] for a in dt_annos_part], 0)
dt_boxes = np.concatenate([a['bbox'] for a in gt_annos_part], 0)
overlap_part = image_box_overlap(gt_boxes, dt_boxes)
elif metric == 1:
loc = np.concatenate(
[a['location'][:, [0, 2]] for a in gt_annos_part], 0)
[a['location'][:, [0, 2]] for a in dt_annos_part], 0)
dims = np.concatenate(
[a['dimensions'][:, [0, 2]] for a in gt_annos_part], 0)
rots = np.concatenate([a['rotation_y'] for a in gt_annos_part], 0)
[a['dimensions'][:, [0, 2]] for a in dt_annos_part], 0)
rots = np.concatenate([a['rotation_y'] for a in dt_annos_part], 0)
gt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]],
axis=1)
loc = np.concatenate(
[a['location'][:, [0, 2]] for a in dt_annos_part], 0)
[a['location'][:, [0, 2]] for a in gt_annos_part], 0)
dims = np.concatenate(
[a['dimensions'][:, [0, 2]] for a in dt_annos_part], 0)
rots = np.concatenate([a['rotation_y'] for a in dt_annos_part], 0)
[a['dimensions'][:, [0, 2]] for a in gt_annos_part], 0)
rots = np.concatenate([a['rotation_y'] for a in gt_annos_part], 0)
dt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]],
axis=1)
overlap_part = bev_box_overlap(gt_boxes,
dt_boxes).astype(np.float64)
elif metric == 2:
loc = np.concatenate([a['location'] for a in gt_annos_part], 0)
dims = np.concatenate([a['dimensions'] for a in gt_annos_part], 0)
rots = np.concatenate([a['rotation_y'] for a in gt_annos_part], 0)
gt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]],
axis=1)
loc = np.concatenate([a['location'] for a in dt_annos_part], 0)
dims = np.concatenate([a['dimensions'] for a in dt_annos_part], 0)
rots = np.concatenate([a['rotation_y'] for a in dt_annos_part], 0)
gt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]],
axis=1)
loc = np.concatenate([a['location'] for a in gt_annos_part], 0)
dims = np.concatenate([a['dimensions'] for a in gt_annos_part], 0)
rots = np.concatenate([a['rotation_y'] for a in gt_annos_part], 0)
dt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]],
axis=1)
overlap_part = d3_box_overlap(gt_boxes,
......@@ -402,8 +403,8 @@ def calculate_iou_partly(gt_annos, dt_annos, metric, num_parts=50):
overlaps = []
example_idx = 0
for j, num_part in enumerate(split_parts):
gt_annos_part = gt_annos[example_idx:example_idx + num_part]
dt_annos_part = dt_annos[example_idx:example_idx + num_part]
gt_annos_part = gt_annos[example_idx:example_idx + num_part]
gt_num_idx, dt_num_idx = 0, 0
for i in range(num_part):
gt_box_num = total_gt_num[example_idx + i]
......@@ -480,6 +481,7 @@ def eval_class(gt_annos,
rets = calculate_iou_partly(dt_annos, gt_annos, metric, num_parts)
overlaps, parted_overlaps, total_dt_num, total_gt_num = rets
N_SAMPLE_PTS = 41
num_minoverlap = len(min_overlaps)
num_class = len(current_classes)
......
# Copyright (c) OpenMMLab. All rights reserved.
from .kitti_metric import KittiMetric # noqa: F401,F403
__all_ = ['KittiMetric']
This diff is collapsed.
import numpy as np
import pytest
import torch
from mmengine.data import InstanceData
from mmdet3d.core import Det3DDataSample
from mmdet3d.core.bbox import LiDARInstance3DBoxes
from mmdet3d.metrics import KittiMetric
data_root = 'tests/data/kitti'
def _init_evaluate_input():
data_batch = [dict(data_sample=dict(sample_idx=0))]
predictions = Det3DDataSample()
pred_instances_3d = InstanceData()
pred_instances_3d.bboxes_3d = LiDARInstance3DBoxes(
torch.tensor(
[[8.7314, -1.8559, -1.5997, 0.4800, 1.2000, 1.8900, 0.0100]]))
pred_instances_3d.scores_3d = torch.Tensor([0.9])
pred_instances_3d.labels_3d = torch.Tensor([0])
predictions.pred_instances_3d = pred_instances_3d
predictions = predictions.to_dict()
return data_batch, [predictions]
def _init_multi_modal_evaluate_input():
data_batch = [dict(data_sample=dict(sample_idx=0))]
predictions = Det3DDataSample()
pred_instances_3d = InstanceData()
pred_instances = InstanceData()
pred_instances.bboxes = torch.tensor([[712.4, 143, 810.7, 307.92]])
pred_instances.scores = torch.Tensor([0.9])
pred_instances.labels = torch.Tensor([0])
pred_instances_3d.bboxes_3d = LiDARInstance3DBoxes(
torch.tensor(
[[8.7314, -1.8559, -1.5997, 0.4800, 1.2000, 1.8900, 0.0100]]))
pred_instances_3d.scores_3d = torch.Tensor([0.9])
pred_instances_3d.labels_3d = torch.Tensor([0])
predictions.pred_instances_3d = pred_instances_3d
predictions.pred_instances = pred_instances
predictions = predictions.to_dict()
return data_batch, [predictions]
def test_multi_modal_kitti_metric():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
kittimetric = KittiMetric(
data_root + '/kitti_infos_train.pkl', metric=['mAP'])
kittimetric.dataset_meta = dict(CLASSES=['Car', 'Pedestrian', 'Cyclist'])
data_batch, predictions = _init_multi_modal_evaluate_input()
kittimetric.process(data_batch, predictions)
ap_dict = kittimetric.compute_metrics(kittimetric.results)
assert np.isclose(ap_dict['pred_instances_3d/KITTI/Overall_3D_AP11_easy'],
3.0303030303030307)
assert np.isclose(ap_dict['pred_instances_3d/KITTI/Overall_BEV_AP11_easy'],
3.0303030303030307)
assert np.isclose(ap_dict['pred_instances_3d/KITTI/Overall_2D_AP11_easy'],
3.0303030303030307)
assert np.isclose(ap_dict['pred_instances/KITTI/Overall_2D_AP11_easy'],
3.0303030303030307)
assert np.isclose(ap_dict['pred_instances/KITTI/Overall_2D_AP11_moderate'],
3.0303030303030307)
assert np.isclose(ap_dict['pred_instances/KITTI/Overall_2D_AP11_hard'],
3.0303030303030307)
def test_kitti_metric_mAP():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
kittimetric = KittiMetric(
data_root + '/kitti_infos_train.pkl', metric=['mAP'])
kittimetric.dataset_meta = dict(CLASSES=['Car', 'Pedestrian', 'Cyclist'])
data_batch, predictions = _init_evaluate_input()
kittimetric.process(data_batch, predictions)
ap_dict = kittimetric.compute_metrics(kittimetric.results)
assert np.isclose(ap_dict['pred_instances_3d/KITTI/Overall_3D_AP11_easy'],
3.0303030303030307)
assert np.isclose(
ap_dict['pred_instances_3d/KITTI/Overall_3D_AP11_moderate'],
3.0303030303030307)
assert np.isclose(ap_dict['pred_instances_3d/KITTI/Overall_3D_AP11_hard'],
3.0303030303030307)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment