Commit 41db4eae authored by VVsssssk's avatar VVsssssk Committed by ChaimZhu
Browse files

[Fix]Fix about kitti pp bug

parent c9ad3605
......@@ -2,6 +2,7 @@ voxel_size = [0.16, 0.16, 4]
model = dict(
type='VoxelNet',
data_preprocessor=dict(type='Det3DDataPreprocessor'),
voxel_layer=dict(
max_num_points=32, # max_points_per_voxel
point_cloud_range=[0, -39.68, -3, 69.12, 39.68, 1],
......
......@@ -26,7 +26,6 @@ model = dict(
allowed_border=0,
pos_weight=-1,
debug=False))
# dataset settings
dataset_type = 'KittiDataset'
data_root = 'data/kitti/'
......@@ -52,7 +51,6 @@ train_pipeline = [
dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='PointShuffle'),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(
type='Pack3DDetInputs',
keys=['points', 'gt_labels_3d', 'gt_bboxes_3d'])
......@@ -82,4 +80,5 @@ train_dataloader = dict(
type='RepeatDataset',
times=2,
dataset=dict(pipeline=train_pipeline, metainfo=metainfo)))
test_dataloader = dict(dataset=dict(pipeline=test_pipeline, metainfo=metainfo))
test_dataloader = dict(dataset=dict(metainfo=metainfo))
val_dataloader = dict(dataset=dict(metainfo=metainfo))
......@@ -352,8 +352,8 @@ def calculate_iou_partly(dt_annos, gt_annos, metric, num_parts=50):
num_parts (int): A parameter for fast calculate algorithm.
"""
assert len(dt_annos) == len(gt_annos)
total_dt_num = np.stack([len(a['name']) for a in gt_annos], 0)
total_gt_num = np.stack([len(a['name']) for a in dt_annos], 0)
total_gt_num = np.stack([len(a['name']) for a in gt_annos], 0)
total_dt_num = np.stack([len(a['name']) for a in dt_annos], 0)
num_examples = len(dt_annos)
split_parts = get_split_parts(num_examples, num_parts)
parted_overlaps = []
......@@ -363,39 +363,39 @@ def calculate_iou_partly(dt_annos, gt_annos, metric, num_parts=50):
dt_annos_part = dt_annos[example_idx:example_idx + num_part]
gt_annos_part = gt_annos[example_idx:example_idx + num_part]
if metric == 0:
gt_boxes = np.concatenate([a['bbox'] for a in dt_annos_part], 0)
dt_boxes = np.concatenate([a['bbox'] for a in gt_annos_part], 0)
overlap_part = image_box_overlap(gt_boxes, dt_boxes)
dt_boxes = np.concatenate([a['bbox'] for a in dt_annos_part], 0)
gt_boxes = np.concatenate([a['bbox'] for a in gt_annos_part], 0)
overlap_part = image_box_overlap(dt_boxes, gt_boxes)
elif metric == 1:
loc = np.concatenate(
[a['location'][:, [0, 2]] for a in dt_annos_part], 0)
dims = np.concatenate(
[a['dimensions'][:, [0, 2]] for a in dt_annos_part], 0)
rots = np.concatenate([a['rotation_y'] for a in dt_annos_part], 0)
gt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]],
dt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]],
axis=1)
loc = np.concatenate(
[a['location'][:, [0, 2]] for a in gt_annos_part], 0)
dims = np.concatenate(
[a['dimensions'][:, [0, 2]] for a in gt_annos_part], 0)
rots = np.concatenate([a['rotation_y'] for a in gt_annos_part], 0)
dt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]],
gt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]],
axis=1)
overlap_part = bev_box_overlap(gt_boxes,
dt_boxes).astype(np.float64)
overlap_part = bev_box_overlap(dt_boxes,
gt_boxes).astype(np.float64)
elif metric == 2:
loc = np.concatenate([a['location'] for a in dt_annos_part], 0)
dims = np.concatenate([a['dimensions'] for a in dt_annos_part], 0)
rots = np.concatenate([a['rotation_y'] for a in dt_annos_part], 0)
gt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]],
dt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]],
axis=1)
loc = np.concatenate([a['location'] for a in gt_annos_part], 0)
dims = np.concatenate([a['dimensions'] for a in gt_annos_part], 0)
rots = np.concatenate([a['rotation_y'] for a in gt_annos_part], 0)
dt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]],
gt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]],
axis=1)
overlap_part = d3_box_overlap(gt_boxes,
dt_boxes).astype(np.float64)
overlap_part = d3_box_overlap(dt_boxes,
gt_boxes).astype(np.float64)
else:
raise ValueError('unknown metric')
parted_overlaps.append(overlap_part)
......@@ -403,20 +403,18 @@ def calculate_iou_partly(dt_annos, gt_annos, metric, num_parts=50):
overlaps = []
example_idx = 0
for j, num_part in enumerate(split_parts):
dt_annos_part = dt_annos[example_idx:example_idx + num_part]
gt_annos_part = gt_annos[example_idx:example_idx + num_part]
gt_num_idx, dt_num_idx = 0, 0
for i in range(num_part):
gt_box_num = total_gt_num[example_idx + i]
dt_box_num = total_dt_num[example_idx + i]
overlaps.append(
parted_overlaps[j][gt_num_idx:gt_num_idx + gt_box_num,
dt_num_idx:dt_num_idx + dt_box_num])
parted_overlaps[j][dt_num_idx:dt_num_idx + dt_box_num,
gt_num_idx:gt_num_idx + gt_box_num])
gt_num_idx += gt_box_num
dt_num_idx += dt_box_num
example_idx += num_part
return overlaps, parted_overlaps, total_gt_num, total_dt_num
return overlaps, parted_overlaps, total_dt_num, total_gt_num
def _prepare_data(gt_annos, dt_annos, current_class, difficulty):
......
......@@ -39,7 +39,10 @@ class KittiDataset(Det3DDataset):
Default: [0, -40, -3, 70.4, 40, 0.0].
"""
# TODO: use full classes of kitti
METAINFO = {'CLASSES': ('Pedestrian', 'Cyclist', 'Car')}
METAINFO = {
'CLASSES': ('Pedestrian', 'Cyclist', 'Car', 'Van', 'Truck',
'Person_sitting', 'Tram', 'Misc')
}
def __init__(self,
data_root: str,
......
......@@ -64,13 +64,21 @@ class KittiMetric(BaseMetric):
raise KeyError("metric should be one of 'bbox', 'img_bbox', "
'but got {metric}.')
def convert_annos_to_kitti_annos(self, data_annos: list,
classes: list) -> list:
def convert_annos_to_kitti_annos(
self,
data_annos: list,
classes: list = [
'Pedestrian', 'Cyclist', 'Car', 'Van', 'Truck', 'Person_sitting',
'Tram', 'Misc'
]
) -> list:
"""Convert loading annotations to Kitti annotations.
Args:
data_annos (list[dict]): Annotations loaded from ann_file.
classes (list[str]): Classes used in the dataset.
classes (list[str]): Classes used in the dataset. Default used
['Pedestrian', 'Cyclist', 'Car', 'Van', 'Truck',
'Person_sitting', 'Tram', 'Misc'].
Returns:
List[dict]: List of Kitti annotations.
......@@ -102,7 +110,10 @@ class KittiMetric(BaseMetric):
'score': []
}
for instance in annos['instances']:
kitti_annos['name'].append(classes[instance['bbox_label']])
labels = instance['bbox_label']
if labels == -1:
continue
kitti_annos['name'].append(classes[labels])
kitti_annos['truncated'].append(instance['truncated'])
kitti_annos['occluded'].append(instance['occluded'])
kitti_annos['alpha'].append(instance['alpha'])
......@@ -148,7 +159,7 @@ class KittiMetric(BaseMetric):
for pred_result in pred:
for attr_name in pred[pred_result]:
pred[pred_result][attr_name] = pred[pred_result][
attr_name].to(self.collect_device)
attr_name].to('cpu')
result[pred_result] = pred[pred_result]
sample_idx = data['data_sample']['sample_idx']
result['sample_idx'] = sample_idx
......@@ -165,14 +176,11 @@ class KittiMetric(BaseMetric):
the metrics, and the values are corresponding results.
"""
logger: MMLogger = MMLogger.get_current_instance()
self.classes = self.dataset_meta['CLASSES']
# load annotations
pkl_annos = self.load_annotations(self.ann_file)['data_list']
self.data_infos = self.convert_annos_to_kitti_annos(
pkl_annos, self.classes)
self.data_infos = self.convert_annos_to_kitti_annos(pkl_annos)
result_dict, tmp_dir = self.format_results(
results,
pklfile_prefix=self.pklfile_prefix,
......@@ -200,7 +208,7 @@ class KittiMetric(BaseMetric):
return metric_dict
def kitti_evaluate(self,
result_dict: List[dict],
results_dict: List[dict],
gt_annos: List[dict],
metric: str = None,
classes: List[str] = None,
......@@ -221,13 +229,13 @@ class KittiMetric(BaseMetric):
dict[str, float]: Results of each evaluation metric.
"""
ap_dict = dict()
for name in result_dict:
for name in results_dict:
if name == 'pred_instances' or metric == 'img_bbox':
eval_types = ['bbox']
else:
eval_types = ['bbox', 'bev', '3d']
ap_result_str, ap_dict_ = kitti_eval(
gt_annos, result_dict[name], classes, eval_types=eval_types)
gt_annos, results_dict[name], classes, eval_types=eval_types)
for ap_type, ap in ap_dict_.items():
ap_dict[f'{name}/{ap_type}'] = float('{:.4f}'.format(ap))
......
......@@ -212,7 +212,8 @@ def update_kitti_infos(pkl_path, out_dir):
# TODO update to full label
# TODO discuss how to process 'Van', 'DontCare'
METAINFO = {
'CLASSES': ('Pedestrian', 'Cyclist', 'Car'),
'CLASSES': ('Pedestrian', 'Cyclist', 'Car', 'Van', 'Truck',
'Person_sitting', 'Tram', 'Misc'),
}
print(f'Reading from input file: {pkl_path}.')
data_list = mmcv.load(pkl_path)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment