Commit a5d463d7 authored by VVsssssk, committed by ChaimZhu

[Refactor] Refactor Lyft

parent ed115937
......@@ -10,12 +10,8 @@ dataset_type = 'LyftDataset'
data_root = 'data/lyft/'
# Input modality for the Lyft dataset. This is consistent with the submission
# format, which requires the information in input_modality.
input_modality = dict(
use_lidar=True,
use_camera=False,
use_radar=False,
use_map=False,
use_external=False)
input_modality = dict(use_lidar=True, use_camera=False)
data_prefix = dict(pts='samples/LIDAR_TOP', img='')
file_client_args = dict(backend='disk')
# Uncomment the following if you use Ceph or other file clients.
# See https://mmcv.readthedocs.io/en/latest/api.html#mmcv.fileio.FileClient
......@@ -47,8 +43,9 @@ train_pipeline = [
dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='PointShuffle'),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
dict(
type='Pack3DDetInputs',
keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
dict(
......@@ -74,13 +71,9 @@ test_pipeline = [
translation_std=[0, 0, 0]),
dict(type='RandomFlip3D'),
dict(
type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(
type='DefaultFormatBundle3D',
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['points'])
])
type='PointsRangeFilter', point_cloud_range=point_cloud_range)
]),
dict(type='Pack3DDetInputs', keys=['points'])
]
# Construct a pipeline for data and GT loading in the show function.
# Please keep its loading function consistent with test_pipeline (e.g. client).
......@@ -95,42 +88,61 @@ eval_pipeline = [
type='LoadPointsFromMultiSweeps',
sweeps_num=10,
file_client_args=file_client_args),
dict(
type='DefaultFormatBundle3D',
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['points'])
dict(type='Pack3DDetInputs', keys=['points'])
]
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
train=dict(
train_dataloader = dict(
batch_size=2,
num_workers=2,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file=data_root + 'lyft_infos_train.pkl',
ann_file='lyft_infos_train.pkl',
pipeline=train_pipeline,
classes=class_names,
metainfo=dict(CLASSES=class_names),
modality=input_modality,
test_mode=False),
val=dict(
data_prefix=data_prefix,
test_mode=False,
box_type_3d='LiDAR'))
test_dataloader = dict(
batch_size=1,
num_workers=1,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file=data_root + 'lyft_infos_val.pkl',
ann_file='lyft_infos_val.pkl',
pipeline=test_pipeline,
classes=class_names,
metainfo=dict(CLASSES=class_names),
modality=input_modality,
test_mode=True),
test=dict(
data_prefix=data_prefix,
test_mode=True,
box_type_3d='LiDAR'))
val_dataloader = dict(
batch_size=1,
num_workers=1,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file=data_root + 'lyft_infos_test.pkl',
ann_file='lyft_infos_val.pkl',
pipeline=test_pipeline,
classes=class_names,
metainfo=dict(CLASSES=class_names),
modality=input_modality,
test_mode=True))
# For Lyft dataset, we usually evaluate the model at the end of training.
# Since the models are trained for 24 epochs by default, we set the
# evaluation interval to 24. Please change the interval accordingly if you do
# not use a default schedule.
evaluation = dict(interval=24, pipeline=eval_pipeline)
test_mode=True,
data_prefix=data_prefix,
box_type_3d='LiDAR'))
val_evaluator = dict(
type='LyftMetric',
ann_file=data_root + 'lyft_infos_val.pkl',
metric='bbox')
test_evaluator = dict(
type='LyftMetric',
ann_file=data_root + 'lyft_infos_val.pkl',
metric='bbox')
......@@ -41,3 +41,9 @@ model = dict(
],
rotations=[0, 1.57],
reshape_out=True)))
# For Lyft dataset, we usually evaluate the model at the end of training.
# Since the models are trained for 24 epochs by default, we set the
# evaluation interval to 24. Please change the interval accordingly if you do
# not use a default schedule.
train_cfg = dict(val_interval=24)
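For context, below is a minimal sketch of how a full config built on this refactored base could be consumed through MMEngine's standard entry points. The config path is hypothetical, and the snippet assumes a complete config (model, schedule, runtime) that inherits this dataset base.

from mmengine.config import Config
from mmengine.runner import Runner

# Hypothetical path: any full Lyft config that inherits this dataset base.
cfg = Config.fromfile('configs/pointpillars/pointpillars_lyft-3d.py')
cfg.work_dir = './work_dirs/lyft_refactor_check'
runner = Runner.from_cfg(cfg)
# With train_cfg = dict(val_interval=24), validation and LyftMetric
# evaluation run once at the end of the default 24-epoch schedule.
runner.train()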
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmcv.transforms.base import BaseTransform
from mmengine.data import InstanceData
from mmengine.registry import TRANSFORMS
from mmdet3d.core.bbox import LiDARInstance3DBoxes
from mmdet3d.core.data_structures import Det3DDataSample
from mmdet3d.datasets import LyftDataset
def _generate_lyft_dataset_config():
data_root = 'tests/data/lyft'
ann_file = 'lyft_infos.pkl'
classes = [
'car', 'truck', 'bus', 'emergency_vehicle', 'other_vehicle',
'motorcycle', 'bicycle', 'pedestrian', 'animal'
]
if 'Identity' not in TRANSFORMS:
@TRANSFORMS.register_module()
class Identity(BaseTransform):
def transform(self, info):
packed_input = dict(data_sample=Det3DDataSample())
if 'ann_info' in info:
packed_input['data_sample'].gt_instances_3d = InstanceData()
packed_input['data_sample'].gt_instances_3d.labels_3d = info['ann_info']['gt_labels_3d']
return packed_input
pipeline = [
dict(type='Identity'),
]
modality = dict(use_lidar=True, use_camera=False)
data_prefix = dict(pts='lidar', img='')
return data_root, ann_file, classes, data_prefix, pipeline, modality
def test_getitem():
np.random.seed(0)
data_root, ann_file, classes, data_prefix, pipeline, modality = \
_generate_lyft_dataset_config()
lyft_dataset = LyftDataset(
data_root,
ann_file,
data_prefix=data_prefix,
pipeline=pipeline,
metainfo=dict(CLASSES=classes),
modality=modality)
lyft_dataset.prepare_data(0)
input_dict = lyft_dataset.get_data_info(0)
# assert that the path contains data_prefix and data_root
assert input_dict['lidar_points'][
'lidar_path'] == 'tests/data/lyft/lidar/host-a017_lidar1_' \
'1236118886901125926.bin'
ann_info = lyft_dataset.parse_ann_info(input_dict)
# assert the keys in ann_info and the type
assert 'gt_labels_3d' in ann_info
assert ann_info['gt_labels_3d'].dtype == np.int64
assert len(ann_info['gt_labels_3d']) == 3
assert 'gt_bboxes_3d' in ann_info
assert isinstance(ann_info['gt_bboxes_3d'], LiDARInstance3DBoxes)
assert len(lyft_dataset.metainfo['CLASSES']) == 9
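For reference, a minimal sketch of running the new unit test on its own; the test file path is an assumption about where this test lives in the repository.

import pytest

# Hypothetical location of the Lyft dataset test added in this commit.
pytest.main(['-q', 'tests/test_datasets/test_lyft_dataset.py::test_getitem'])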
......@@ -561,6 +561,110 @@ def update_sunrgbd_infos(pkl_path, out_dir):
mmcv.dump(converted_data_info, out_path, 'pkl')
def update_lyft_infos(pkl_path, out_dir):
print(f'{pkl_path} will be modified.')
if out_dir in pkl_path:
print(f'Warning, you may be overwriting '
f'the original data {pkl_path}.')
print(f'Reading from input file: {pkl_path}.')
data_list = mmcv.load(pkl_path)
METAINFO = {
'CLASSES':
('car', 'truck', 'bus', 'emergency_vehicle', 'other_vehicle',
'motorcycle', 'bicycle', 'pedestrian', 'animal'),
'DATASET':
'Lyft',
'version':
data_list['metadata']['version']
}
print('Start updating:')
converted_list = []
for i, ori_info_dict in enumerate(
mmcv.track_iter_progress(data_list['infos'])):
temp_data_info = get_empty_standard_data_info()
temp_data_info['sample_idx'] = i
temp_data_info['token'] = ori_info_dict['token']
temp_data_info['ego2global'] = convert_quaternion_to_matrix(
ori_info_dict['ego2global_rotation'],
ori_info_dict['ego2global_translation'])
temp_data_info['lidar_points']['lidar_path'] = ori_info_dict[
'lidar_path'].split('/')[-1]
temp_data_info['lidar_points'][
'lidar2ego'] = convert_quaternion_to_matrix(
ori_info_dict['lidar2ego_rotation'],
ori_info_dict['lidar2ego_translation'])
# bc-breaking: Timestamp has been divided by 1e6 in pkl infos.
temp_data_info['timestamp'] = ori_info_dict['timestamp'] / 1e6
for ori_sweep in ori_info_dict['sweeps']:
temp_lidar_sweep = get_single_lidar_sweep()
temp_lidar_sweep['lidar_points'][
'lidar2ego'] = convert_quaternion_to_matrix(
ori_sweep['sensor2ego_rotation'],
ori_sweep['sensor2ego_translation'])
temp_lidar_sweep['ego2global'] = convert_quaternion_to_matrix(
ori_sweep['ego2global_rotation'],
ori_sweep['ego2global_translation'])
lidar2sensor = np.eye(4)
lidar2sensor[:3, :3] = ori_sweep['sensor2lidar_rotation'].T
lidar2sensor[:3, 3] = -ori_sweep['sensor2lidar_translation']
temp_lidar_sweep['lidar_points'][
'lidar2sensor'] = lidar2sensor.astype(np.float32).tolist()
# bc-breaking: Timestamp has been divided by 1e6 in pkl infos.
temp_lidar_sweep['timestamp'] = ori_sweep['timestamp'] / 1e6
temp_lidar_sweep['lidar_points']['lidar_path'] = ori_sweep[
'data_path']
temp_lidar_sweep['sample_data_token'] = ori_sweep[
'sample_data_token']
temp_data_info['lidar_sweeps'].append(temp_lidar_sweep)
temp_data_info['images'] = {}
for cam in ori_info_dict['cams']:
empty_img_info = get_empty_img_info()
empty_img_info['img_path'] = ori_info_dict['cams'][cam][
'data_path'].split('/')[-1]
empty_img_info['cam2img'] = ori_info_dict['cams'][cam][
'cam_intrinsic'].tolist()
empty_img_info['sample_data_token'] = ori_info_dict['cams'][cam][
'sample_data_token']
empty_img_info[
'timestamp'] = ori_info_dict['cams'][cam]['timestamp'] / 1e6
empty_img_info['cam2ego'] = convert_quaternion_to_matrix(
ori_info_dict['cams'][cam]['sensor2ego_rotation'],
ori_info_dict['cams'][cam]['sensor2ego_translation'])
lidar2sensor = np.eye(4)
lidar2sensor[:3, :3] = ori_info_dict['cams'][cam][
'sensor2lidar_rotation'].T
lidar2sensor[:3, 3] = -ori_info_dict['cams'][cam][
'sensor2lidar_translation']
empty_img_info['lidar2cam'] = lidar2sensor.astype(
np.float32).tolist()
temp_data_info['images'][cam] = empty_img_info
num_instances = ori_info_dict['gt_boxes'].shape[0]
ignore_class_name = set()
for i in range(num_instances):
empty_instance = get_empty_instance()
empty_instance['bbox_3d'] = ori_info_dict['gt_boxes'][
i, :].tolist()
if ori_info_dict['gt_names'][i] in METAINFO['CLASSES']:
empty_instance['bbox_label'] = METAINFO['CLASSES'].index(
ori_info_dict['gt_names'][i])
else:
ignore_class_name.add(ori_info_dict['gt_names'][i])
empty_instance['bbox_label'] = -1
empty_instance['bbox_label_3d'] = copy.deepcopy(
empty_instance['bbox_label'])
empty_instance = clear_instance_unused_keys(empty_instance)
temp_data_info['instances'].append(empty_instance)
temp_data_info, _ = clear_data_info_unused_keys(temp_data_info)
converted_list.append(temp_data_info)
pkl_name = pkl_path.split('/')[-1]
out_path = osp.join(out_dir, pkl_name)
print(f'Writing to output file: {out_path}.')
print(f'ignore classes: {ignore_class_name}')
converted_data_info = dict(metainfo=METAINFO, data_list=converted_list)
mmcv.dump(converted_data_info, out_path, 'pkl')
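For readers unfamiliar with the helper used throughout update_lyft_infos, the sketch below shows what convert_quaternion_to_matrix is assumed to do: build a 4x4 homogeneous transform from a [w, x, y, z] quaternion and a translation vector. This is an illustration of the assumed behaviour, not the exact implementation in this repository.

import numpy as np
from pyquaternion import Quaternion


def convert_quaternion_to_matrix_sketch(quaternion, translation):
    # Assumed behaviour: the quaternion's rotation fills the upper-left 3x3
    # block, the translation fills the last column, and the result is
    # returned as a nested list so it can be dumped into the pkl infos.
    transform = np.eye(4)
    transform[:3, :3] = Quaternion(quaternion).rotation_matrix
    transform[:3, 3] = np.asarray(translation)
    return transform.astype(np.float32).tolist()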
def parse_args():
parser = argparse.ArgumentParser(description='Arg parser for data coords '
'update due to coords sys refactor.')
......@@ -592,6 +696,8 @@ def main():
update_scannet_infos(pkl_path=args.pkl, out_dir=args.out_dir)
elif args.dataset.lower() == 'sunrgbd':
update_sunrgbd_infos(pkl_path=args.pkl, out_dir=args.out_dir)
elif args.dataset.lower() == 'lyft':
update_lyft_infos(pkl_path=args.pkl, out_dir=args.out_dir)
elif args.dataset.lower() == 'nuscenes':
update_nuscenes_infos(pkl_path=args.pkl, out_dir=args.out_dir)
else:
......