Commit a5d463d7 authored by VVsssssk's avatar VVsssssk Committed by ChaimZhu
Browse files

[Refactor]Refactor lyft

parent ed115937
...@@ -10,12 +10,8 @@ dataset_type = 'LyftDataset' ...@@ -10,12 +10,8 @@ dataset_type = 'LyftDataset'
data_root = 'data/lyft/' data_root = 'data/lyft/'
# Input modality for Lyft dataset, this is consistent with the submission # Input modality for Lyft dataset, this is consistent with the submission
# format which requires the information in input_modality. # format which requires the information in input_modality.
input_modality = dict( input_modality = dict(use_lidar=True, use_camera=False)
use_lidar=True, data_prefix = dict(pts='samples/LIDAR_TOP', img='')
use_camera=False,
use_radar=False,
use_map=False,
use_external=False)
file_client_args = dict(backend='disk') file_client_args = dict(backend='disk')
# Uncomment the following if use ceph or other file clients. # Uncomment the following if use ceph or other file clients.
# See https://mmcv.readthedocs.io/en/latest/api.html#mmcv.fileio.FileClient # See https://mmcv.readthedocs.io/en/latest/api.html#mmcv.fileio.FileClient
...@@ -47,8 +43,9 @@ train_pipeline = [ ...@@ -47,8 +43,9 @@ train_pipeline = [
dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range), dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range), dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='PointShuffle'), dict(type='PointShuffle'),
dict(type='DefaultFormatBundle3D', class_names=class_names), dict(
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) type='Pack3DDetInputs',
keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
] ]
test_pipeline = [ test_pipeline = [
dict( dict(
...@@ -74,13 +71,9 @@ test_pipeline = [ ...@@ -74,13 +71,9 @@ test_pipeline = [
translation_std=[0, 0, 0]), translation_std=[0, 0, 0]),
dict(type='RandomFlip3D'), dict(type='RandomFlip3D'),
dict( dict(
type='PointsRangeFilter', point_cloud_range=point_cloud_range), type='PointsRangeFilter', point_cloud_range=point_cloud_range)
dict( ]),
type='DefaultFormatBundle3D', dict(type='Pack3DDetInputs', keys=['points'])
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['points'])
])
] ]
# construct a pipeline for data and gt loading in show function # construct a pipeline for data and gt loading in show function
# please keep its loading function consistent with test_pipeline (e.g. client) # please keep its loading function consistent with test_pipeline (e.g. client)
...@@ -95,42 +88,61 @@ eval_pipeline = [ ...@@ -95,42 +88,61 @@ eval_pipeline = [
type='LoadPointsFromMultiSweeps', type='LoadPointsFromMultiSweeps',
sweeps_num=10, sweeps_num=10,
file_client_args=file_client_args), file_client_args=file_client_args),
dict( dict(type='Pack3DDetInputs', keys=['points'])
type='DefaultFormatBundle3D',
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['points'])
] ]
train_dataloader = dict(
data = dict( batch_size=2,
samples_per_gpu=2, num_workers=2,
workers_per_gpu=2, persistent_workers=True,
train=dict( sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type, type=dataset_type,
data_root=data_root, data_root=data_root,
ann_file=data_root + 'lyft_infos_train.pkl', ann_file='lyft_infos_train.pkl',
pipeline=train_pipeline, pipeline=train_pipeline,
classes=class_names, metainfo=dict(CLASSES=class_names),
modality=input_modality, modality=input_modality,
test_mode=False), data_prefix=data_prefix,
val=dict( test_mode=False,
box_type_3d='LiDAR'))
test_dataloader = dict(
batch_size=1,
num_workers=1,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type, type=dataset_type,
data_root=data_root, data_root=data_root,
ann_file=data_root + 'lyft_infos_val.pkl', ann_file='lyft_infos_val.pkl',
pipeline=test_pipeline, pipeline=test_pipeline,
classes=class_names, metainfo=dict(CLASSES=class_names),
modality=input_modality, modality=input_modality,
test_mode=True), data_prefix=data_prefix,
test=dict( test_mode=True,
box_type_3d='LiDAR'))
val_dataloader = dict(
batch_size=1,
num_workers=1,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type, type=dataset_type,
data_root=data_root, data_root=data_root,
ann_file=data_root + 'lyft_infos_test.pkl', ann_file='lyft_infos_val.pkl',
pipeline=test_pipeline, pipeline=test_pipeline,
classes=class_names, metainfo=dict(CLASSES=class_names),
modality=input_modality, modality=input_modality,
test_mode=True)) test_mode=True,
# For Lyft dataset, we usually evaluate the model at the end of training. data_prefix=data_prefix,
# Since the models are trained by 24 epochs by default, we set evaluation box_type_3d='LiDAR'))
# interval to be 24. Please change the interval accordingly if you do not
# use a default schedule. val_evaluator = dict(
evaluation = dict(interval=24, pipeline=eval_pipeline) type='LyftMetric',
ann_file=data_root + 'lyft_infos_val.pkl',
metric='bbox')
test_evaluator = dict(
type='LyftMetric',
ann_file=data_root + 'lyft_infos_val.pkl',
metric='bbox')
...@@ -41,3 +41,9 @@ model = dict( ...@@ -41,3 +41,9 @@ model = dict(
], ],
rotations=[0, 1.57], rotations=[0, 1.57],
reshape_out=True))) reshape_out=True)))
# For Lyft dataset, we usually evaluate the model at the end of training.
# Since the models are trained by 24 epochs by default, we set evaluation
# interval to be 24. Please change the interval accordingly if you do not
# use a default schedule.
train_cfg = dict(val_interval=24)
This diff is collapsed.
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmcv.transforms.base import BaseTransform
from mmengine.data import InstanceData
from mmengine.registry import TRANSFORMS
from mmdet3d.core.bbox import LiDARInstance3DBoxes
from mmdet3d.core.data_structures import Det3DDataSample
from mmdet3d.datasets import LyftDataset
def _generate_nus_dataset_config():
    """Build a minimal Lyft dataset config for unit tests.

    Registers (at most once) a no-op ``Identity`` transform that packs only
    the ground-truth 3D labels into a ``Det3DDataSample``.

    Returns:
        tuple: ``(data_root, ann_file, classes, data_prefix, pipeline,
        modality)`` suitable for constructing a ``LyftDataset``.
    """
    # NOTE(review): the "nus" in the function name is historical — this
    # config points at the Lyft test fixtures.
    if 'Identity' not in TRANSFORMS:

        @TRANSFORMS.register_module()
        class Identity(BaseTransform):

            def transform(self, info):
                # Pack a bare sample; copy over gt labels when present.
                sample = Det3DDataSample()
                if 'ann_info' in info:
                    instances = InstanceData()
                    instances.labels_3d = info['ann_info']['gt_labels_3d']
                    sample.gt_instances_3d = instances
                return dict(data_sample=sample)

    class_names = [
        'car', 'truck', 'bus', 'emergency_vehicle', 'other_vehicle',
        'motorcycle', 'bicycle', 'pedestrian', 'animal'
    ]
    return (
        'tests/data/lyft',  # data_root
        'lyft_infos.pkl',  # ann_file
        class_names,  # classes
        dict(pts='lidar', img=''),  # data_prefix
        [dict(type='Identity')],  # pipeline
        dict(use_lidar=True, use_camera=False),  # modality
    )
def test_getitem():
    """Smoke-test LyftDataset: data loading, path resolution and
    annotation parsing on the bundled test fixture."""
    np.random.seed(0)
    data_root, ann_file, classes, data_prefix, pipeline, modality = \
        _generate_nus_dataset_config()
    lyft_dataset = LyftDataset(
        data_root,
        ann_file,
        data_prefix=data_prefix,
        pipeline=pipeline,
        metainfo=dict(CLASSES=classes),
        modality=modality)
    lyft_dataset.prepare_data(0)
    input_dict = lyft_dataset.get_data_info(0)
    # assert that the resolved lidar path contains both data_root and
    # data_prefix
    assert input_dict['lidar_points'][
        'lidar_path'] == 'tests/data/lyft/lidar/host-a017_lidar1_' \
        '1236118886901125926.bin'
    ann_info = lyft_dataset.parse_ann_info(input_dict)
    # assert the expected keys in ann_info and their types
    assert 'gt_labels_3d' in ann_info
    assert ann_info['gt_labels_3d'].dtype == np.int64
    assert len(ann_info['gt_labels_3d']) == 3
    assert 'gt_bboxes_3d' in ann_info
    assert isinstance(ann_info['gt_bboxes_3d'], LiDARInstance3DBoxes)
    assert len(lyft_dataset.metainfo['CLASSES']) == 9
...@@ -561,6 +561,110 @@ def update_sunrgbd_infos(pkl_path, out_dir): ...@@ -561,6 +561,110 @@ def update_sunrgbd_infos(pkl_path, out_dir):
mmcv.dump(converted_data_info, out_path, 'pkl') mmcv.dump(converted_data_info, out_path, 'pkl')
def update_lyft_infos(pkl_path, out_dir):
    """Convert a legacy Lyft info pkl file to the new standard data format.

    Args:
        pkl_path (str): Path of the legacy Lyft info pkl file.
        out_dir (str): Directory the converted pkl is written to; the
            original file name is kept, so passing the source directory
            overwrites the input (a warning is printed in that case).
    """
    print(f'{pkl_path} will be modified.')
    if out_dir in pkl_path:
        # fixed message grammar ("you may overwriting")
        print(f'Warning, you may be overwriting '
              f'the original data {pkl_path}.')
    print(f'Reading from input file: {pkl_path}.')
    data_list = mmcv.load(pkl_path)
    METAINFO = {
        'CLASSES':
        ('car', 'truck', 'bus', 'emergency_vehicle', 'other_vehicle',
         'motorcycle', 'bicycle', 'pedestrian', 'animal'),
        # NOTE(review): 'Nuscenes' looks like a copy-paste leftover in this
        # Lyft converter — confirm whether downstream readers depend on it
        # before renaming to 'Lyft'.
        'DATASET':
        'Nuscenes',
        'version':
        data_list['metadata']['version']
    }
    print('Start updating:')
    converted_list = []
    # Collect unknown class names across ALL samples. Previously this set
    # was re-created inside the loop, so the summary printed at the end only
    # reflected the last sample.
    ignore_class_name = set()
    for sample_idx, ori_info_dict in enumerate(
            mmcv.track_iter_progress(data_list['infos'])):
        temp_data_info = get_empty_standard_data_info()
        temp_data_info['sample_idx'] = sample_idx
        temp_data_info['token'] = ori_info_dict['token']
        temp_data_info['ego2global'] = convert_quaternion_to_matrix(
            ori_info_dict['ego2global_rotation'],
            ori_info_dict['ego2global_translation'])
        # keep only the file name; data_prefix supplies the directory
        temp_data_info['lidar_points']['lidar_path'] = ori_info_dict[
            'lidar_path'].split('/')[-1]
        temp_data_info['lidar_points'][
            'lidar2ego'] = convert_quaternion_to_matrix(
                ori_info_dict['lidar2ego_rotation'],
                ori_info_dict['lidar2ego_translation'])
        # bc-breaking: timestamps are divided by 1e6 in the new pkl infos.
        temp_data_info['timestamp'] = ori_info_dict['timestamp'] / 1e6
        for ori_sweep in ori_info_dict['sweeps']:
            temp_lidar_sweep = get_single_lidar_sweep()
            temp_lidar_sweep['lidar_points'][
                'lidar2ego'] = convert_quaternion_to_matrix(
                    ori_sweep['sensor2ego_rotation'],
                    ori_sweep['sensor2ego_translation'])
            temp_lidar_sweep['ego2global'] = convert_quaternion_to_matrix(
                ori_sweep['ego2global_rotation'],
                ori_sweep['ego2global_translation'])
            # invert sensor->lidar: R^T and -t give the lidar->sensor pose
            lidar2sensor = np.eye(4)
            lidar2sensor[:3, :3] = ori_sweep['sensor2lidar_rotation'].T
            lidar2sensor[:3, 3] = -ori_sweep['sensor2lidar_translation']
            temp_lidar_sweep['lidar_points'][
                'lidar2sensor'] = lidar2sensor.astype(np.float32).tolist()
            # bc-breaking: timestamps are divided by 1e6 in the new pkl infos.
            temp_lidar_sweep['timestamp'] = ori_sweep['timestamp'] / 1e6
            temp_lidar_sweep['lidar_points']['lidar_path'] = ori_sweep[
                'data_path']
            temp_lidar_sweep['sample_data_token'] = ori_sweep[
                'sample_data_token']
            temp_data_info['lidar_sweeps'].append(temp_lidar_sweep)
        temp_data_info['images'] = {}
        for cam in ori_info_dict['cams']:
            # hoist the repeated nested lookup
            ori_cam_info = ori_info_dict['cams'][cam]
            empty_img_info = get_empty_img_info()
            empty_img_info['img_path'] = ori_cam_info['data_path'].split(
                '/')[-1]
            empty_img_info['cam2img'] = ori_cam_info['cam_intrinsic'].tolist()
            empty_img_info['sample_data_token'] = ori_cam_info[
                'sample_data_token']
            empty_img_info['timestamp'] = ori_cam_info['timestamp'] / 1e6
            empty_img_info['cam2ego'] = convert_quaternion_to_matrix(
                ori_cam_info['sensor2ego_rotation'],
                ori_cam_info['sensor2ego_translation'])
            lidar2sensor = np.eye(4)
            lidar2sensor[:3, :3] = ori_cam_info['sensor2lidar_rotation'].T
            lidar2sensor[:3, 3] = -ori_cam_info['sensor2lidar_translation']
            empty_img_info['lidar2cam'] = lidar2sensor.astype(
                np.float32).tolist()
            temp_data_info['images'][cam] = empty_img_info
        num_instances = ori_info_dict['gt_boxes'].shape[0]
        # use a distinct index name; the original reused `i` and shadowed
        # the outer sample index
        for inst_idx in range(num_instances):
            empty_instance = get_empty_instance()
            empty_instance['bbox_3d'] = ori_info_dict['gt_boxes'][
                inst_idx, :].tolist()
            if ori_info_dict['gt_names'][inst_idx] in METAINFO['CLASSES']:
                empty_instance['bbox_label'] = METAINFO['CLASSES'].index(
                    ori_info_dict['gt_names'][inst_idx])
            else:
                # unknown category: record it and mark the label ignored
                ignore_class_name.add(ori_info_dict['gt_names'][inst_idx])
                empty_instance['bbox_label'] = -1
            empty_instance['bbox_label_3d'] = copy.deepcopy(
                empty_instance['bbox_label'])
            empty_instance = clear_instance_unused_keys(empty_instance)
            temp_data_info['instances'].append(empty_instance)
        temp_data_info, _ = clear_data_info_unused_keys(temp_data_info)
        converted_list.append(temp_data_info)
    pkl_name = pkl_path.split('/')[-1]
    out_path = osp.join(out_dir, pkl_name)
    print(f'Writing to output file: {out_path}.')
    print(f'ignore classes: {ignore_class_name}')
    converted_data_info = dict(metainfo=METAINFO, data_list=converted_list)
    mmcv.dump(converted_data_info, out_path, 'pkl')
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='Arg parser for data coords ' parser = argparse.ArgumentParser(description='Arg parser for data coords '
'update due to coords sys refactor.') 'update due to coords sys refactor.')
...@@ -592,6 +696,8 @@ def main(): ...@@ -592,6 +696,8 @@ def main():
update_scannet_infos(pkl_path=args.pkl, out_dir=args.out_dir) update_scannet_infos(pkl_path=args.pkl, out_dir=args.out_dir)
elif args.dataset.lower() == 'sunrgbd': elif args.dataset.lower() == 'sunrgbd':
update_sunrgbd_infos(pkl_path=args.pkl, out_dir=args.out_dir) update_sunrgbd_infos(pkl_path=args.pkl, out_dir=args.out_dir)
elif args.dataset.lower() == 'lyft':
update_lyft_infos(pkl_path=args.pkl, out_dir=args.out_dir)
elif args.dataset.lower() == 'nuscenes': elif args.dataset.lower() == 'nuscenes':
update_nuscenes_infos(pkl_path=args.pkl, out_dir=args.out_dir) update_nuscenes_infos(pkl_path=args.pkl, out_dir=args.out_dir)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment