"...windows/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "7839bdbe1389e734b00529edd9a7566bb8701588"
Commit 660f3ccc authored by zhangwenwei's avatar zhangwenwei
Browse files

Merge branch 'fix_gt_database' into 'master'

Fix gt database

See merge request open-mmlab/mmdet.3d!90
parents d7ade147 dff4b320
...@@ -116,7 +116,8 @@ class KittiDataset(Custom3DDataset): ...@@ -116,7 +116,8 @@ class KittiDataset(Custom3DDataset):
gt_bboxes_3d=gt_bboxes_3d, gt_bboxes_3d=gt_bboxes_3d,
gt_labels_3d=gt_labels_3d, gt_labels_3d=gt_labels_3d,
bboxes=gt_bboxes, bboxes=gt_bboxes,
labels=gt_labels) labels=gt_labels,
gt_names=gt_names)
return anns_results return anns_results
def drop_arrays_by_name(self, gt_names, used_classes): def drop_arrays_by_name(self, gt_names, used_classes):
......
...@@ -182,7 +182,7 @@ class NuScenesDataset(Custom3DDataset): ...@@ -182,7 +182,7 @@ class NuScenesDataset(Custom3DDataset):
anns_results = dict( anns_results = dict(
gt_bboxes_3d=gt_bboxes_3d, gt_bboxes_3d=gt_bboxes_3d,
gt_labels_3d=gt_labels_3d, gt_labels_3d=gt_labels_3d,
) gt_names=gt_names_3d)
return anns_results return anns_results
def _format_bbox(self, results, jsonfile_prefix=None): def _format_bbox(self, results, jsonfile_prefix=None):
......
...@@ -144,27 +144,52 @@ def create_groundtruth_database(dataset_class_name, ...@@ -144,27 +144,52 @@ def create_groundtruth_database(dataset_class_name,
print(f'Create GT Database of {dataset_class_name}') print(f'Create GT Database of {dataset_class_name}')
dataset_cfg = dict( dataset_cfg = dict(
type=dataset_class_name, type=dataset_class_name,
root_path=data_path, data_root=data_path,
ann_file=info_path, ann_file=info_path,
) )
if dataset_class_name == 'KittiDataset': if dataset_class_name == 'KittiDataset':
file_client_args = dict(backend='disk')
dataset_cfg.update( dataset_cfg.update(
training=True, test_mode=False,
split='training', split='training',
modality=dict( modality=dict(
use_lidar=True, use_lidar=True,
use_depth=False, use_depth=False,
use_lidar_intensity=True, use_lidar_intensity=True,
use_camera=with_mask, use_camera=with_mask,
)) ),
pipeline=[
dict(
type='LoadPointsFromFile',
load_dim=4,
use_dim=4,
file_client_args=file_client_args),
dict(
type='LoadAnnotations3D',
with_bbox_3d=True,
with_label_3d=True,
file_client_args=file_client_args)
])
elif dataset_class_name == 'NuScenesDataset':
dataset_cfg.update(pipeline=[
dict(type='LoadPointsFromFile', load_dim=5, use_dim=5),
dict(
type='LoadPointsFromMultiSweeps',
sweeps_num=10,
),
dict(
type='LoadAnnotations3D',
with_bbox_3d=True,
with_label_3d=True)
])
dataset = build_dataset(dataset_cfg) dataset = build_dataset(dataset_cfg)
if database_save_path is None: if database_save_path is None:
database_save_path = osp.join(data_path, database_save_path = osp.join(data_path, f'{info_prefix}_gt_database')
'{}_gt_database'.format(info_prefix))
if db_info_save_path is None: if db_info_save_path is None:
db_info_save_path = osp.join( db_info_save_path = osp.join(data_path,
data_path, '{}_dbinfos_train.pkl'.format(info_prefix)) f'{info_prefix}_dbinfos_train.pkl')
mmcv.mkdir_or_exist(database_save_path) mmcv.mkdir_or_exist(database_save_path)
all_db_infos = dict() all_db_infos = dict()
...@@ -178,14 +203,15 @@ def create_groundtruth_database(dataset_class_name, ...@@ -178,14 +203,15 @@ def create_groundtruth_database(dataset_class_name,
group_counter = 0 group_counter = 0
for j in track_iter_progress(list(range(len(dataset)))): for j in track_iter_progress(list(range(len(dataset)))):
image_idx = j input_dict = dataset.get_data_info(j)
annos = dataset.get_sensor_data(j) dataset.pre_pipeline(input_dict)
image_idx = annos['sample_idx'] example = dataset.pipeline(input_dict)
points = annos['points'] annos = example['ann_info']
gt_boxes_3d = annos['gt_bboxes_3d'] image_idx = example['sample_idx']
points = example['points']
gt_boxes_3d = annos['gt_bboxes_3d'].tensor.numpy()
names = annos['gt_names'] names = annos['gt_names']
group_dict = dict() group_dict = dict()
group_ids = np.full([gt_boxes_3d.shape[0]], -1, dtype=np.int64)
if 'group_ids' in annos: if 'group_ids' in annos:
group_ids = annos['group_ids'] group_ids = annos['group_ids']
else: else:
...@@ -200,9 +226,9 @@ def create_groundtruth_database(dataset_class_name, ...@@ -200,9 +226,9 @@ def create_groundtruth_database(dataset_class_name,
if with_mask: if with_mask:
# prepare masks # prepare masks
gt_boxes = annos['gt_bboxes'] gt_boxes = annos['gt_bboxes']
img_path = annos['filename'].split('/')[-1] img_path = osp.split(example['img_info']['filename'])[-1]
if img_path not in file2id.keys(): if img_path not in file2id.keys():
print('skip image {} for empty mask'.format(img_path)) print(f'skip image {img_path} for empty mask')
continue continue
img_id = file2id[img_path] img_id = file2id[img_path]
kins_annIds = coco.getAnnIds(imgIds=img_id) kins_annIds = coco.getAnnIds(imgIds=img_id)
...@@ -230,7 +256,8 @@ def create_groundtruth_database(dataset_class_name, ...@@ -230,7 +256,8 @@ def create_groundtruth_database(dataset_class_name,
for i in range(num_obj): for i in range(num_obj):
filename = f'{image_idx}_{names[i]}_{i}.bin' filename = f'{image_idx}_{names[i]}_{i}.bin'
filepath = osp.join(database_save_path, filename) abs_filepath = osp.join(database_save_path, filename)
rel_filepath = osp.join(f'{info_prefix}_gt_database', filename)
# save point clouds and image patches for each object # save point clouds and image patches for each object
gt_points = points[point_indices[:, i]] gt_points = points[point_indices[:, i]]
...@@ -240,22 +267,18 @@ def create_groundtruth_database(dataset_class_name, ...@@ -240,22 +267,18 @@ def create_groundtruth_database(dataset_class_name,
if object_masks[i].sum() == 0 or not valid_inds[i]: if object_masks[i].sum() == 0 or not valid_inds[i]:
# Skip object for empty or invalid mask # Skip object for empty or invalid mask
continue continue
img_patch_path = filepath + '.png' img_patch_path = abs_filepath + '.png'
mask_patch_path = filepath + '.mask.png' mask_patch_path = abs_filepath + '.mask.png'
mmcv.imwrite(object_img_patches[i], img_patch_path) mmcv.imwrite(object_img_patches[i], img_patch_path)
mmcv.imwrite(object_masks[i], mask_patch_path) mmcv.imwrite(object_masks[i], mask_patch_path)
with open(filepath, 'w') as f: with open(abs_filepath, 'w') as f:
gt_points.tofile(f) gt_points.tofile(f)
if (used_classes is None) or names[i] in used_classes: if (used_classes is None) or names[i] in used_classes:
if relative_path:
db_path = osp.join(data_path, filename)
else:
db_path = filepath
db_info = { db_info = {
'name': names[i], 'name': names[i],
'path': db_path, 'path': rel_filepath,
'image_idx': image_idx, 'image_idx': image_idx,
'gt_idx': i, 'gt_idx': i,
'box3d_lidar': gt_boxes_3d[i], 'box3d_lidar': gt_boxes_3d[i],
......
...@@ -92,9 +92,8 @@ def create_kitti_info_file(data_path, ...@@ -92,9 +92,8 @@ def create_kitti_info_file(data_path,
relative_path (bool): Whether to use relative path. relative_path (bool): Whether to use relative path.
""" """
imageset_folder = Path(data_path) / 'ImageSets' imageset_folder = Path(data_path) / 'ImageSets'
train_img_ids = _read_imageset_file( train_img_ids = _read_imageset_file(str(imageset_folder / 'train.txt'))
str(imageset_folder / 'train_6014.txt')) val_img_ids = _read_imageset_file(str(imageset_folder / 'val.txt'))
val_img_ids = _read_imageset_file(str(imageset_folder / 'val_1467.txt'))
test_img_ids = _read_imageset_file(str(imageset_folder / 'test.txt')) test_img_ids = _read_imageset_file(str(imageset_folder / 'test.txt'))
print('Generate info. this may take several minutes.') print('Generate info. this may take several minutes.')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment