"git@developer.sourcefind.cn:OpenDAS/torch-sparce.git" did not exist on "09a90659d0fffec423a32f8f74d178f08d703558"
Commit 41db4eae authored by VVsssssk, committed by ChaimZhu

[Fix] Fix KITTI PointPillars config and evaluation bugs

parent c9ad3605
@@ -2,6 +2,7 @@ voxel_size = [0.16, 0.16, 4]
 model = dict(
     type='VoxelNet',
+    data_preprocessor=dict(type='Det3DDataPreprocessor'),
     voxel_layer=dict(
         max_num_points=32,  # max_points_per_voxel
         point_cloud_range=[0, -39.68, -3, 69.12, 39.68, 1],
......
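Note on the hunk above: the new data_preprocessor key appears to plug the model into the refactored data flow, where batch collation and device transfer are handled by the data preprocessor rather than by formatting transforms in the dataset pipeline. A minimal illustration of where the key sits, using only the fields visible in the diff:

    # Illustrative fragment only; the remaining model fields are unchanged.
    model = dict(
        type='VoxelNet',
        data_preprocessor=dict(type='Det3DDataPreprocessor'),
        voxel_layer=dict(
            max_num_points=32,  # max_points_per_voxel
            point_cloud_range=[0, -39.68, -3, 69.12, 39.68, 1]))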
@@ -26,7 +26,6 @@ model = dict(
         allowed_border=0,
         pos_weight=-1,
         debug=False))
 # dataset settings
 dataset_type = 'KittiDataset'
 data_root = 'data/kitti/'
@@ -52,7 +51,6 @@ train_pipeline = [
     dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
     dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
     dict(type='PointShuffle'),
-    dict(type='DefaultFormatBundle3D', class_names=class_names),
     dict(
         type='Pack3DDetInputs',
         keys=['points', 'gt_labels_3d', 'gt_bboxes_3d'])
@@ -82,4 +80,5 @@ train_dataloader = dict(
         type='RepeatDataset',
         times=2,
         dataset=dict(pipeline=train_pipeline, metainfo=metainfo)))
-test_dataloader = dict(dataset=dict(pipeline=test_pipeline, metainfo=metainfo))
+test_dataloader = dict(dataset=dict(metainfo=metainfo))
+val_dataloader = dict(dataset=dict(metainfo=metainfo))
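Taken together, the dataset-config hunks drop formatting as a pipeline step (DefaultFormatBundle3D is removed and Pack3DDetInputs packs points and labels directly) and make the test/val dataloaders override only metainfo, presumably inheriting their pipelines from the base config. A rough reconstruction of the touched section; values not shown in the diff are placeholders:

    # Sketch, not the full config; class_names and metainfo stand in for
    # the definitions earlier in the file.
    point_cloud_range = [0, -39.68, -3, 69.12, 39.68, 1]
    class_names = ['Pedestrian', 'Cyclist', 'Car']  # assumed 3-class setup
    metainfo = dict(CLASSES=class_names)            # assumed metainfo shape

    train_pipeline = [
        # ... loading / sampling / augmentation transforms omitted ...
        dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
        dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
        dict(type='PointShuffle'),
        dict(
            type='Pack3DDetInputs',
            keys=['points', 'gt_labels_3d', 'gt_bboxes_3d'])
    ]

    test_dataloader = dict(dataset=dict(metainfo=metainfo))
    val_dataloader = dict(dataset=dict(metainfo=metainfo))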
@@ -352,8 +352,8 @@ def calculate_iou_partly(dt_annos, gt_annos, metric, num_parts=50):
         num_parts (int): A parameter for fast calculate algorithm.
     """
     assert len(dt_annos) == len(gt_annos)
-    total_dt_num = np.stack([len(a['name']) for a in gt_annos], 0)
-    total_gt_num = np.stack([len(a['name']) for a in dt_annos], 0)
+    total_gt_num = np.stack([len(a['name']) for a in gt_annos], 0)
+    total_dt_num = np.stack([len(a['name']) for a in dt_annos], 0)
     num_examples = len(dt_annos)
     split_parts = get_split_parts(num_examples, num_parts)
     parted_overlaps = []
@@ -363,39 +363,39 @@ def calculate_iou_partly(dt_annos, gt_annos, metric, num_parts=50):
         dt_annos_part = dt_annos[example_idx:example_idx + num_part]
         gt_annos_part = gt_annos[example_idx:example_idx + num_part]
         if metric == 0:
-            gt_boxes = np.concatenate([a['bbox'] for a in dt_annos_part], 0)
-            dt_boxes = np.concatenate([a['bbox'] for a in gt_annos_part], 0)
-            overlap_part = image_box_overlap(gt_boxes, dt_boxes)
+            dt_boxes = np.concatenate([a['bbox'] for a in dt_annos_part], 0)
+            gt_boxes = np.concatenate([a['bbox'] for a in gt_annos_part], 0)
+            overlap_part = image_box_overlap(dt_boxes, gt_boxes)
         elif metric == 1:
             loc = np.concatenate(
                 [a['location'][:, [0, 2]] for a in dt_annos_part], 0)
             dims = np.concatenate(
                 [a['dimensions'][:, [0, 2]] for a in dt_annos_part], 0)
             rots = np.concatenate([a['rotation_y'] for a in dt_annos_part], 0)
-            gt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]],
+            dt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]],
                                       axis=1)
             loc = np.concatenate(
                 [a['location'][:, [0, 2]] for a in gt_annos_part], 0)
             dims = np.concatenate(
                 [a['dimensions'][:, [0, 2]] for a in gt_annos_part], 0)
             rots = np.concatenate([a['rotation_y'] for a in gt_annos_part], 0)
-            dt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]],
+            gt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]],
                                       axis=1)
-            overlap_part = bev_box_overlap(gt_boxes,
-                                           dt_boxes).astype(np.float64)
+            overlap_part = bev_box_overlap(dt_boxes,
+                                           gt_boxes).astype(np.float64)
         elif metric == 2:
             loc = np.concatenate([a['location'] for a in dt_annos_part], 0)
             dims = np.concatenate([a['dimensions'] for a in dt_annos_part], 0)
             rots = np.concatenate([a['rotation_y'] for a in dt_annos_part], 0)
-            gt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]],
+            dt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]],
                                       axis=1)
             loc = np.concatenate([a['location'] for a in gt_annos_part], 0)
             dims = np.concatenate([a['dimensions'] for a in gt_annos_part], 0)
             rots = np.concatenate([a['rotation_y'] for a in gt_annos_part], 0)
-            dt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]],
+            gt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]],
                                       axis=1)
-            overlap_part = d3_box_overlap(gt_boxes,
-                                          dt_boxes).astype(np.float64)
+            overlap_part = d3_box_overlap(dt_boxes,
+                                          gt_boxes).astype(np.float64)
         else:
             raise ValueError('unknown metric')
         parted_overlaps.append(overlap_part)
@@ -403,20 +403,18 @@ def calculate_iou_partly(dt_annos, gt_annos, metric, num_parts=50):
     overlaps = []
     example_idx = 0
     for j, num_part in enumerate(split_parts):
-        dt_annos_part = dt_annos[example_idx:example_idx + num_part]
-        gt_annos_part = gt_annos[example_idx:example_idx + num_part]
         gt_num_idx, dt_num_idx = 0, 0
         for i in range(num_part):
             gt_box_num = total_gt_num[example_idx + i]
             dt_box_num = total_dt_num[example_idx + i]
             overlaps.append(
-                parted_overlaps[j][gt_num_idx:gt_num_idx + gt_box_num,
-                                   dt_num_idx:dt_num_idx + dt_box_num])
+                parted_overlaps[j][dt_num_idx:dt_num_idx + dt_box_num,
+                                   gt_num_idx:gt_num_idx + gt_box_num])
             gt_num_idx += gt_box_num
             dt_num_idx += dt_box_num
         example_idx += num_part
-    return overlaps, parted_overlaps, total_gt_num, total_dt_num
+    return overlaps, parted_overlaps, total_dt_num, total_gt_num


 def _prepare_data(gt_annos, dt_annos, current_class, difficulty):
......
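The three calculate_iou_partly hunks above fix the same inconsistency: the per-part overlap matrices come out of the *_box_overlap helpers with detections along the first axis and ground truths along the second, so the count bookkeeping (total_dt_num / total_gt_num), the per-example slicing, and the return order must all follow the same (dt, gt) convention. A small self-contained NumPy check of that slicing logic (the array contents are made up; only the shapes matter):

    # One "part" covering two examples with 2/1 detections and 3/2 ground
    # truths: the parted overlap matrix has shape (2 + 1, 3 + 2).
    import numpy as np

    parted_overlap = np.arange(15, dtype=np.float64).reshape(3, 5)
    total_dt_num = np.array([2, 1])
    total_gt_num = np.array([3, 2])

    overlaps = []
    dt_idx, gt_idx = 0, 0
    for dt_n, gt_n in zip(total_dt_num, total_gt_num):
        # Rows index detections, columns index ground truths, matching the
        # corrected slicing in calculate_iou_partly.
        overlaps.append(parted_overlap[dt_idx:dt_idx + dt_n,
                                       gt_idx:gt_idx + gt_n])
        dt_idx += dt_n
        gt_idx += gt_n

    assert overlaps[0].shape == (2, 3)
    assert overlaps[1].shape == (1, 2)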
@@ -39,7 +39,10 @@ class KittiDataset(Det3DDataset):
             Default: [0, -40, -3, 70.4, 40, 0.0].
     """
     # TODO: use full classes of kitti
-    METAINFO = {'CLASSES': ('Pedestrian', 'Cyclist', 'Car')}
+    METAINFO = {
+        'CLASSES': ('Pedestrian', 'Cyclist', 'Car', 'Van', 'Truck',
+                    'Person_sitting', 'Tram', 'Misc')
+    }

     def __init__(self,
                  data_root: str,
......
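The widened METAINFO (here, in KittiMetric's default classes argument, and in update_kitti_infos below) appears to keep the class tuple in step with the labels stored in the info files: with only three names, any annotation labelled as e.g. 'Van' would index past the end of the tuple when mapped back to a class name. A toy illustration of the lookup that the longer tuple makes safe (names local to this sketch):

    # Label indices follow the order of the CLASSES tuple.
    CLASSES = ('Pedestrian', 'Cyclist', 'Car', 'Van', 'Truck',
               'Person_sitting', 'Tram', 'Misc')

    assert CLASSES[2] == 'Car'
    # With the old 3-name tuple, a label of 3 ('Van') would raise IndexError.
    assert CLASSES[3] == 'Van'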
@@ -64,13 +64,21 @@ class KittiMetric(BaseMetric):
             raise KeyError("metric should be one of 'bbox', 'img_bbox', "
                            'but got {metric}.')

-    def convert_annos_to_kitti_annos(self, data_annos: list,
-                                     classes: list) -> list:
+    def convert_annos_to_kitti_annos(
+        self,
+        data_annos: list,
+        classes: list = [
+            'Pedestrian', 'Cyclist', 'Car', 'Van', 'Truck', 'Person_sitting',
+            'Tram', 'Misc'
+        ]
+    ) -> list:
         """Convert loading annotations to Kitti annotations.

         Args:
             data_annos (list[dict]): Annotations loaded from ann_file.
-            classes (list[str]): Classes used in the dataset.
+            classes (list[str]): Classes used in the dataset. Default used
+                ['Pedestrian', 'Cyclist', 'Car', 'Van', 'Truck',
+                'Person_sitting', 'Tram', 'Misc'].

         Returns:
             List[dict]: List of Kitti annotations.
@@ -102,7 +110,10 @@ class KittiMetric(BaseMetric):
                 'score': []
             }
             for instance in annos['instances']:
-                kitti_annos['name'].append(classes[instance['bbox_label']])
+                labels = instance['bbox_label']
+                if labels == -1:
+                    continue
+                kitti_annos['name'].append(classes[labels])
                 kitti_annos['truncated'].append(instance['truncated'])
                 kitti_annos['occluded'].append(instance['occluded'])
                 kitti_annos['alpha'].append(instance['alpha'])
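The label check added above guards the same lookup against ignored instances: a bbox_label of -1 (presumably 'DontCare' or other ignored boxes) would otherwise silently map to the last class name through Python's negative indexing. A minimal sketch of the filtering behaviour, using placeholder instance dicts:

    classes = ('Pedestrian', 'Cyclist', 'Car', 'Van', 'Truck',
               'Person_sitting', 'Tram', 'Misc')

    # Placeholder annotations; only the field used here is filled in.
    instances = [
        {'bbox_label': 2},    # Car
        {'bbox_label': -1},   # ignored box: must be skipped, not mapped
        {'bbox_label': 0},    # Pedestrian
    ]

    names = []
    for instance in instances:
        label = instance['bbox_label']
        if label == -1:
            continue          # same guard as in convert_annos_to_kitti_annos
        names.append(classes[label])

    assert names == ['Car', 'Pedestrian']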
@@ -148,7 +159,7 @@ class KittiMetric(BaseMetric):
             for pred_result in pred:
                 for attr_name in pred[pred_result]:
                     pred[pred_result][attr_name] = pred[pred_result][
-                        attr_name].to(self.collect_device)
+                        attr_name].to('cpu')
                 result[pred_result] = pred[pred_result]
             sample_idx = data['data_sample']['sample_idx']
             result['sample_idx'] = sample_idx
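Moving every prediction attribute to 'cpu' (instead of self.collect_device) before caching it means the per-sample results no longer hold references to GPU memory for the whole evaluation run, while collect_device keeps governing how the cached results are gathered across processes later. A stand-alone sketch of the pattern; the pred dict here is a stand-in, not the real data sample:

    import torch

    # Stand-in for one prediction group produced during testing.
    pred = {'bboxes_3d': torch.rand(4, 7), 'scores_3d': torch.rand(4)}
    if torch.cuda.is_available():
        pred = {k: v.cuda() for k, v in pred.items()}

    # Mirror of the change above: cache CPU copies of every attribute.
    result = {k: v.to('cpu') for k, v in pred.items()}
    assert all(v.device.type == 'cpu' for v in result.values())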
@@ -165,14 +176,11 @@ class KittiMetric(BaseMetric):
                 the metrics, and the values are corresponding results.
         """
         logger: MMLogger = MMLogger.get_current_instance()
         self.classes = self.dataset_meta['CLASSES']
         # load annotations
         pkl_annos = self.load_annotations(self.ann_file)['data_list']
-        self.data_infos = self.convert_annos_to_kitti_annos(
-            pkl_annos, self.classes)
+        self.data_infos = self.convert_annos_to_kitti_annos(pkl_annos)
         result_dict, tmp_dir = self.format_results(
             results,
             pklfile_prefix=self.pklfile_prefix,
@@ -200,7 +208,7 @@ class KittiMetric(BaseMetric):
         return metric_dict

     def kitti_evaluate(self,
-                       result_dict: List[dict],
+                       results_dict: List[dict],
                        gt_annos: List[dict],
                        metric: str = None,
                        classes: List[str] = None,
@@ -221,13 +229,13 @@ class KittiMetric(BaseMetric):
             dict[str, float]: Results of each evaluation metric.
         """
         ap_dict = dict()
-        for name in result_dict:
+        for name in results_dict:
             if name == 'pred_instances' or metric == 'img_bbox':
                 eval_types = ['bbox']
             else:
                 eval_types = ['bbox', 'bev', '3d']
             ap_result_str, ap_dict_ = kitti_eval(
-                gt_annos, result_dict[name], classes, eval_types=eval_types)
+                gt_annos, results_dict[name], classes, eval_types=eval_types)
             for ap_type, ap in ap_dict_.items():
                 ap_dict[f'{name}/{ap_type}'] = float('{:.4f}'.format(ap))
......
@@ -212,7 +212,8 @@ def update_kitti_infos(pkl_path, out_dir):
     # TODO update to full label
     # TODO discuss how to process 'Van', 'DontCare'
     METAINFO = {
-        'CLASSES': ('Pedestrian', 'Cyclist', 'Car'),
+        'CLASSES': ('Pedestrian', 'Cyclist', 'Car', 'Van', 'Truck',
+                    'Person_sitting', 'Tram', 'Misc'),
     }
     print(f'Reading from input file: {pkl_path}.')
     data_list = mmcv.load(pkl_path)
......