Commit 97f388fd authored by liyinhao's avatar liyinhao
Browse files

Can create kitti_gt

parent d59a0287
...@@ -147,6 +147,7 @@ def create_groundtruth_database(dataset_class_name, ...@@ -147,6 +147,7 @@ def create_groundtruth_database(dataset_class_name,
data_root=data_path, data_root=data_path,
ann_file=info_path, ann_file=info_path,
) )
file_client_args = dict(backend='disk')
if dataset_class_name == 'KittiDataset': if dataset_class_name == 'KittiDataset':
dataset_cfg.update( dataset_cfg.update(
test_mode=False, test_mode=False,
...@@ -156,7 +157,19 @@ def create_groundtruth_database(dataset_class_name, ...@@ -156,7 +157,19 @@ def create_groundtruth_database(dataset_class_name,
use_depth=False, use_depth=False,
use_lidar_intensity=True, use_lidar_intensity=True,
use_camera=with_mask, use_camera=with_mask,
)) ),
pipeline=[
dict(
type='LoadPointsFromFile',
load_dim=4,
use_dim=4,
file_client_args=file_client_args),
dict(
type='LoadAnnotations3D',
with_bbox_3d=True,
with_label_3d=True,
file_client_args=file_client_args)
])
dataset = build_dataset(dataset_cfg) dataset = build_dataset(dataset_cfg)
if database_save_path is None: if database_save_path is None:
...@@ -178,14 +191,15 @@ def create_groundtruth_database(dataset_class_name, ...@@ -178,14 +191,15 @@ def create_groundtruth_database(dataset_class_name,
group_counter = 0 group_counter = 0
for j in track_iter_progress(list(range(len(dataset)))): for j in track_iter_progress(list(range(len(dataset)))):
annos = dataset.get_data_info(j) input_dict = dataset.get_data_info(j)
image_idx = annos['sample_idx'] dataset.pre_pipeline(input_dict)
points = np.fromfile( example = dataset.pipeline(input_dict)
annos['pts_file_name'], dtype=np.float32).reshape(-1, 4) annos = example['ann_info']
gt_boxes_3d = annos['ann_info']['gt_bboxes_3d'] image_idx = example['sample_idx']
points = example['points']
gt_boxes_3d = annos['gt_bboxes_3d'].tensor.numpy()
names = annos['gt_names'] names = annos['gt_names']
group_dict = dict() group_dict = dict()
group_ids = np.full([gt_boxes_3d.shape[0]], -1, dtype=np.int64)
if 'group_ids' in annos: if 'group_ids' in annos:
group_ids = annos['group_ids'] group_ids = annos['group_ids']
else: else:
...@@ -200,7 +214,7 @@ def create_groundtruth_database(dataset_class_name, ...@@ -200,7 +214,7 @@ def create_groundtruth_database(dataset_class_name,
if with_mask: if with_mask:
# prepare masks # prepare masks
gt_boxes = annos['gt_bboxes'] gt_boxes = annos['gt_bboxes']
img_path = annos['filename'].split('/')[-1] img_path = osp.split(example['img_info']['filename'])[-1]
if img_path not in file2id.keys(): if img_path not in file2id.keys():
print('skip image {} for empty mask'.format(img_path)) print('skip image {} for empty mask'.format(img_path))
continue continue
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment