Commit bd73d3b9 authored by jshilong, committed by ChaimZhu

[Refactor] MVXTwoStage & CenterPoint

parent 360c27f9
voxel_size = [0.1, 0.1, 0.2] voxel_size = [0.1, 0.1, 0.2]
model = dict( model = dict(
type='CenterPoint', type='CenterPoint',
data_preprocessor=dict(type='Det3DDataPreprocessor'),
pts_voxel_layer=dict( pts_voxel_layer=dict(
max_num_points=10, voxel_size=voxel_size, max_voxels=(90000, 120000)), max_num_points=10, voxel_size=voxel_size, max_voxels=(90000, 120000)),
pts_voxel_encoder=dict(type='HardSimpleVFE', num_features=5), pts_voxel_encoder=dict(type='HardSimpleVFE', num_features=5),
...@@ -54,8 +55,9 @@ model = dict( ...@@ -54,8 +55,9 @@ model = dict(
code_size=9), code_size=9),
separate_head=dict( separate_head=dict(
type='SeparateHead', init_bias=-2.19, final_kernel=3), type='SeparateHead', init_bias=-2.19, final_kernel=3),
loss_cls=dict(type='GaussianFocalLoss', reduction='mean'), loss_cls=dict(type='mmdet.GaussianFocalLoss', reduction='mean'),
loss_bbox=dict(type='L1Loss', reduction='mean', loss_weight=0.25), loss_bbox=dict(
type='mmdet.L1Loss', reduction='mean', loss_weight=0.25),
norm_bbox=True), norm_bbox=True),
# model training and testing settings # model training and testing settings
train_cfg=dict( train_cfg=dict(
......
voxel_size = [0.2, 0.2, 8] voxel_size = [0.2, 0.2, 8]
model = dict( model = dict(
type='CenterPoint', type='CenterPoint',
data_preprocessor=dict(type='Det3DDataPreprocessor'),
pts_voxel_layer=dict( pts_voxel_layer=dict(
max_num_points=20, voxel_size=voxel_size, max_voxels=(30000, 40000)), max_num_points=20, voxel_size=voxel_size, max_voxels=(30000, 40000)),
pts_voxel_encoder=dict( pts_voxel_encoder=dict(
...@@ -53,8 +54,9 @@ model = dict( ...@@ -53,8 +54,9 @@ model = dict(
code_size=9), code_size=9),
separate_head=dict( separate_head=dict(
type='SeparateHead', init_bias=-2.19, final_kernel=3), type='SeparateHead', init_bias=-2.19, final_kernel=3),
loss_cls=dict(type='GaussianFocalLoss', reduction='mean'), loss_cls=dict(type='mmdet.GaussianFocalLoss', reduction='mean'),
loss_bbox=dict(type='L1Loss', reduction='mean', loss_weight=0.25), loss_bbox=dict(
type='mmdet.L1Loss', reduction='mean', loss_weight=0.25),
norm_bbox=True), norm_bbox=True),
# model training and testing settings # model training and testing settings
train_cfg=dict( train_cfg=dict(
......
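The only functional change in the two CenterPoint config hunks above is the `mmdet.` prefix on the loss types, which asks the MMEngine-style registry to resolve the class from MMDetection's scope rather than MMDetection3D's own. A minimal sketch of how such a prefixed config is built, assuming an installed dev-1.x mmdet3d with its registries importable (illustration only, not part of this commit):

# Sketch: the 'mmdet.' scope prefix routes the lookup to MMDetection's registry.
from mmdet3d.registry import MODELS

loss_cls = MODELS.build(dict(type='mmdet.GaussianFocalLoss', reduction='mean'))
loss_bbox = MODELS.build(dict(type='mmdet.L1Loss', reduction='mean', loss_weight=0.25))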
# This schedule is mainly used by models with dynamic voxelization # This schedule is mainly used by models with dynamic voxelization
# optimizer # optimizer
lr = 0.003 # max learning rate lr = 0.003 # max learning rate
optimizer = dict( optim_wrapper = dict(
type='AdamW', type='OptimWrapper',
lr=lr, optimizer=dict(
betas=(0.95, 0.99), # the momentum changes during training type='AdamW', lr=lr, weight_decay=0.001, betas=(0.95, 0.99)),
weight_decay=0.001) clip_grad=dict(max_norm=10, norm_type=2),
optimizer_config = dict(grad_clip=dict(max_norm=10, norm_type=2)) )
lr_config = dict( param_scheduler = [
policy='CosineAnnealing', dict(type='LinearLR', start_factor=0.1, by_epoch=False, begin=0, end=1000),
warmup='linear', dict(
warmup_iters=1000, type='CosineAnnealingLR',
warmup_ratio=1.0 / 10, begin=0,
min_lr_ratio=1e-5) T_max=40,
end=40,
momentum_config = None by_epoch=True,
eta_min=1e-5)
runner = dict(type='EpochBasedRunner', max_epochs=40) ]
# training schedule for 40 epochs
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=40, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
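The legacy `optimizer`, `optimizer_config`, `lr_config` and `runner` fields above are folded into `optim_wrapper`, `param_scheduler` and the `train_cfg`/`val_cfg`/`test_cfg` loop configs. A minimal sketch of how a refactored config is consumed, assuming MMEngine is installed; the config path is hypothetical and stands in for a full config (model, dataloaders and this schedule) written in the new style:

from mmengine.config import Config
from mmengine.runner import Runner

# Hypothetical path; substitute any complete refactored config file.
cfg = Config.fromfile('configs/some_refactored_config.py')
cfg.work_dir = './work_dirs/demo'
runner = Runner.from_cfg(cfg)  # builds the optim wrapper, schedulers and loops from the fields above
runner.train()                 # runs EpochBasedTrainLoop for max_epochs=40, validating every epoch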
...@@ -9,7 +9,7 @@ class_names = [ ...@@ -9,7 +9,7 @@ class_names = [
'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier', 'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',
'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone' 'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
] ]
data_prefix = dict(pts='samples/LIDAR_TOP', img='')
model = dict( model = dict(
pts_voxel_layer=dict( pts_voxel_layer=dict(
voxel_size=voxel_size, point_cloud_range=point_cloud_range), voxel_size=voxel_size, point_cloud_range=point_cloud_range),
...@@ -96,7 +96,9 @@ train_pipeline = [ ...@@ -96,7 +96,9 @@ train_pipeline = [
dict(type='ObjectNameFilter', classes=class_names), dict(type='ObjectNameFilter', classes=class_names),
dict(type='PointShuffle'), dict(type='PointShuffle'),
dict(type='DefaultFormatBundle3D', class_names=class_names), dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) dict(
type='Pack3DDetInputs',
keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
] ]
test_pipeline = [ test_pipeline = [
dict( dict(
...@@ -125,16 +127,15 @@ test_pipeline = [ ...@@ -125,16 +127,15 @@ test_pipeline = [
translation_std=[0, 0, 0]), translation_std=[0, 0, 0]),
dict(type='RandomFlip3D'), dict(type='RandomFlip3D'),
dict( dict(
type='PointsRangeFilter', point_cloud_range=point_cloud_range), type='PointsRangeFilter', point_cloud_range=point_cloud_range)
dict( ]),
type='DefaultFormatBundle3D', dict(type='Pack3DDetInputs', keys=['points'])
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['points'])
])
] ]
train_dataloader = dict(
data = dict( dataset=dict(
train=dict(dataset=dict(pipeline=train_pipeline)), dataset=dict(
val=dict(pipeline=test_pipeline), pipeline=train_pipeline, metainfo=dict(CLASSES=class_names))))
test=dict(pipeline=test_pipeline)) test_dataloader = dict(
dataset=dict(pipeline=test_pipeline, metainfo=dict(CLASSES=class_names)))
val_dataloader = dict(
dataset=dict(pipeline=test_pipeline, metainfo=dict(CLASSES=class_names)))
...@@ -12,7 +12,7 @@ class_names = [ ...@@ -12,7 +12,7 @@ class_names = [
'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier', 'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',
'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone' 'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
] ]
data_prefix = dict(pts='samples/LIDAR_TOP', img='')
model = dict( model = dict(
pts_voxel_layer=dict(point_cloud_range=point_cloud_range), pts_voxel_layer=dict(point_cloud_range=point_cloud_range),
pts_bbox_head=dict(bbox_coder=dict(pc_range=point_cloud_range[:2])), pts_bbox_head=dict(bbox_coder=dict(pc_range=point_cloud_range[:2])),
...@@ -90,8 +90,9 @@ train_pipeline = [ ...@@ -90,8 +90,9 @@ train_pipeline = [
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range), dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectNameFilter', classes=class_names), dict(type='ObjectNameFilter', classes=class_names),
dict(type='PointShuffle'), dict(type='PointShuffle'),
dict(type='DefaultFormatBundle3D', class_names=class_names), dict(
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) type='Pack3DDetInputs',
keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
] ]
test_pipeline = [ test_pipeline = [
dict( dict(
...@@ -120,13 +121,9 @@ test_pipeline = [ ...@@ -120,13 +121,9 @@ test_pipeline = [
translation_std=[0, 0, 0]), translation_std=[0, 0, 0]),
dict(type='RandomFlip3D'), dict(type='RandomFlip3D'),
dict( dict(
type='PointsRangeFilter', point_cloud_range=point_cloud_range), type='PointsRangeFilter', point_cloud_range=point_cloud_range)
dict( ]),
type='DefaultFormatBundle3D', dict(type='Pack3DDetInputs', keys=['points'])
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['points'])
])
] ]
# construct a pipeline for data and gt loading in show function # construct a pipeline for data and gt loading in show function
# please keep its loading function consistent with test_pipeline (e.g. client) # please keep its loading function consistent with test_pipeline (e.g. client)
...@@ -144,28 +141,31 @@ eval_pipeline = [ ...@@ -144,28 +141,31 @@ eval_pipeline = [
file_client_args=file_client_args, file_client_args=file_client_args,
pad_empty_sweeps=True, pad_empty_sweeps=True,
remove_close=True), remove_close=True),
dict( dict(type='Pack3DDetInputs', keys=['points'])
type='DefaultFormatBundle3D',
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['points'])
] ]
train_dataloader = dict(
data = dict( _delete_=True,
train=dict( batch_size=4,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type='CBGSDataset', type='CBGSDataset',
dataset=dict( dataset=dict(
type=dataset_type, type=dataset_type,
data_root=data_root, data_root=data_root,
ann_file=data_root + 'nuscenes_infos_train.pkl', ann_file='nuscenes_infos_train.pkl',
pipeline=train_pipeline, pipeline=train_pipeline,
classes=class_names, metainfo=dict(CLASSES=class_names),
test_mode=False, test_mode=False,
data_prefix=data_prefix,
use_valid_flag=True, use_valid_flag=True,
# we use box_type_3d='LiDAR' in kitti and nuscenes dataset # we use box_type_3d='LiDAR' in kitti and nuscenes dataset
# and box_type_3d='Depth' in sunrgbd and scannet dataset. # and box_type_3d='Depth' in sunrgbd and scannet dataset.
box_type_3d='LiDAR')), box_type_3d='LiDAR')))
val=dict(pipeline=test_pipeline, classes=class_names), test_dataloader = dict(
test=dict(pipeline=test_pipeline, classes=class_names)) dataset=dict(pipeline=test_pipeline, metainfo=dict(CLASSES=class_names)))
val_dataloader = dict(
dataset=dict(pipeline=test_pipeline, metainfo=dict(CLASSES=class_names)))
evaluation = dict(interval=20, pipeline=eval_pipeline) train_cfg = dict(val_interval=20)
...@@ -12,7 +12,7 @@ class_names = [ ...@@ -12,7 +12,7 @@ class_names = [
'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier', 'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',
'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone' 'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
] ]
data_prefix = dict(pts='samples/LIDAR_TOP', img='')
model = dict( model = dict(
pts_voxel_layer=dict(point_cloud_range=point_cloud_range), pts_voxel_layer=dict(point_cloud_range=point_cloud_range),
pts_voxel_encoder=dict(point_cloud_range=point_cloud_range), pts_voxel_encoder=dict(point_cloud_range=point_cloud_range),
...@@ -91,8 +91,9 @@ train_pipeline = [ ...@@ -91,8 +91,9 @@ train_pipeline = [
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range), dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectNameFilter', classes=class_names), dict(type='ObjectNameFilter', classes=class_names),
dict(type='PointShuffle'), dict(type='PointShuffle'),
dict(type='DefaultFormatBundle3D', class_names=class_names), dict(
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) type='Pack3DDetInputs',
keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
] ]
test_pipeline = [ test_pipeline = [
dict( dict(
...@@ -119,13 +120,9 @@ test_pipeline = [ ...@@ -119,13 +120,9 @@ test_pipeline = [
rot_range=[0, 0], rot_range=[0, 0],
scale_ratio_range=[1., 1.], scale_ratio_range=[1., 1.],
translation_std=[0, 0, 0]), translation_std=[0, 0, 0]),
dict(type='RandomFlip3D'), dict(type='RandomFlip3D')
dict( ]),
type='DefaultFormatBundle3D', dict(type='Pack3DDetInputs', keys=['points'])
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['points'])
])
] ]
# construct a pipeline for data and gt loading in show function # construct a pipeline for data and gt loading in show function
# please keep its loading function consistent with test_pipeline (e.g. client) # please keep its loading function consistent with test_pipeline (e.g. client)
...@@ -143,28 +140,31 @@ eval_pipeline = [ ...@@ -143,28 +140,31 @@ eval_pipeline = [
file_client_args=file_client_args, file_client_args=file_client_args,
pad_empty_sweeps=True, pad_empty_sweeps=True,
remove_close=True), remove_close=True),
dict( dict(type='Pack3DDetInputs', keys=['points'])
type='DefaultFormatBundle3D',
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['points'])
] ]
train_dataloader = dict(
data = dict( _delete_=True,
train=dict( batch_size=4,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type='CBGSDataset', type='CBGSDataset',
dataset=dict( dataset=dict(
type=dataset_type, type=dataset_type,
data_root=data_root, data_root=data_root,
ann_file=data_root + 'nuscenes_infos_train.pkl', ann_file='nuscenes_infos_train.pkl',
pipeline=train_pipeline, pipeline=train_pipeline,
classes=class_names, metainfo=dict(CLASSES=class_names),
test_mode=False, test_mode=False,
data_prefix=data_prefix,
use_valid_flag=True, use_valid_flag=True,
# we use box_type_3d='LiDAR' in kitti and nuscenes dataset # we use box_type_3d='LiDAR' in kitti and nuscenes dataset
# and box_type_3d='Depth' in sunrgbd and scannet dataset. # and box_type_3d='Depth' in sunrgbd and scannet dataset.
box_type_3d='LiDAR')), box_type_3d='LiDAR')))
val=dict(pipeline=test_pipeline, classes=class_names), test_dataloader = dict(
test=dict(pipeline=test_pipeline, classes=class_names)) dataset=dict(pipeline=test_pipeline, metainfo=dict(CLASSES=class_names)))
val_dataloader = dict(
dataset=dict(pipeline=test_pipeline, metainfo=dict(CLASSES=class_names)))
evaluation = dict(interval=20, pipeline=eval_pipeline) train_cfg = dict(val_interval=20)
...@@ -6,8 +6,14 @@ point_cloud_range = [0, -40, -3, 70.4, 40, 1] ...@@ -6,8 +6,14 @@ point_cloud_range = [0, -40, -3, 70.4, 40, 1]
model = dict( model = dict(
type='DynamicMVXFasterRCNN', type='DynamicMVXFasterRCNN',
data_preprocessor=dict(
type='Det3DDataPreprocessor',
mean=[102.9801, 115.9465, 122.7717],
std=[1.0, 1.0, 1.0],
bgr_to_rgb=False,
pad_size_divisor=32),
img_backbone=dict( img_backbone=dict(
type='ResNet', type='mmdet.ResNet',
depth=50, depth=50,
num_stages=4, num_stages=4,
out_indices=(0, 1, 2, 3), out_indices=(0, 1, 2, 3),
...@@ -16,7 +22,7 @@ model = dict( ...@@ -16,7 +22,7 @@ model = dict(
norm_eval=True, norm_eval=True,
style='caffe'), style='caffe'),
img_neck=dict( img_neck=dict(
type='FPN', type='mmdet.FPN',
in_channels=[256, 512, 1024, 2048], in_channels=[256, 512, 1024, 2048],
out_channels=256, out_channels=256,
num_outs=5), num_outs=5),
...@@ -82,34 +88,36 @@ model = dict( ...@@ -82,34 +88,36 @@ model = dict(
assign_per_class=True, assign_per_class=True,
bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'), bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'),
loss_cls=dict( loss_cls=dict(
type='FocalLoss', type='mmdet.FocalLoss',
use_sigmoid=True, use_sigmoid=True,
gamma=2.0, gamma=2.0,
alpha=0.25, alpha=0.25,
loss_weight=1.0), loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0), loss_bbox=dict(
type='mmdet.SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0),
loss_dir=dict( loss_dir=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.2)), type='mmdet.CrossEntropyLoss', use_sigmoid=False,
loss_weight=0.2)),
# model training and testing settings # model training and testing settings
train_cfg=dict( train_cfg=dict(
pts=dict( pts=dict(
assigner=[ assigner=[
dict( # for Pedestrian dict( # for Pedestrian
type='MaxIoUAssigner', type='Max3DIoUAssigner',
iou_calculator=dict(type='BboxOverlapsNearest3D'), iou_calculator=dict(type='BboxOverlapsNearest3D'),
pos_iou_thr=0.35, pos_iou_thr=0.35,
neg_iou_thr=0.2, neg_iou_thr=0.2,
min_pos_iou=0.2, min_pos_iou=0.2,
ignore_iof_thr=-1), ignore_iof_thr=-1),
dict( # for Cyclist dict( # for Cyclist
type='MaxIoUAssigner', type='Max3DIoUAssigner',
iou_calculator=dict(type='BboxOverlapsNearest3D'), iou_calculator=dict(type='BboxOverlapsNearest3D'),
pos_iou_thr=0.35, pos_iou_thr=0.35,
neg_iou_thr=0.2, neg_iou_thr=0.2,
min_pos_iou=0.2, min_pos_iou=0.2,
ignore_iof_thr=-1), ignore_iof_thr=-1),
dict( # for Car dict( # for Car
type='MaxIoUAssigner', type='Max3DIoUAssigner',
iou_calculator=dict(type='BboxOverlapsNearest3D'), iou_calculator=dict(type='BboxOverlapsNearest3D'),
pos_iou_thr=0.6, pos_iou_thr=0.6,
neg_iou_thr=0.45, neg_iou_thr=0.45,
...@@ -133,18 +141,14 @@ model = dict( ...@@ -133,18 +141,14 @@ model = dict(
dataset_type = 'KittiDataset' dataset_type = 'KittiDataset'
data_root = 'data/kitti/' data_root = 'data/kitti/'
class_names = ['Pedestrian', 'Cyclist', 'Car'] class_names = ['Pedestrian', 'Cyclist', 'Car']
img_norm_cfg = dict( metainfo = dict(CLASSES=class_names)
mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
input_modality = dict(use_lidar=True, use_camera=True) input_modality = dict(use_lidar=True, use_camera=True)
train_pipeline = [ train_pipeline = [
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4), dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict(type='LoadImageFromFile'), dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True), dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict( dict(
type='Resize', type='RandomResize', scale=[(640, 192), (2560, 768)], keep_ratio=True),
img_scale=[(640, 192), (2560, 768)],
multiscale_mode='range',
keep_ratio=True),
dict( dict(
type='GlobalRotScaleTrans', type='GlobalRotScaleTrans',
rot_range=[-0.78539816, 0.78539816], rot_range=[-0.78539816, 0.78539816],
...@@ -154,12 +158,12 @@ train_pipeline = [ ...@@ -154,12 +158,12 @@ train_pipeline = [
dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range), dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range), dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='PointShuffle'), dict(type='PointShuffle'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict( dict(
type='Collect3D', type='Pack3DDetInputs',
keys=['points', 'img', 'gt_bboxes_3d', 'gt_labels_3d']), keys=[
'points', 'img', 'gt_bboxes_3d', 'gt_labels_3d', 'gt_bboxes',
'gt_labels'
])
] ]
test_pipeline = [ test_pipeline = [
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4), dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
...@@ -170,82 +174,79 @@ test_pipeline = [ ...@@ -170,82 +174,79 @@ test_pipeline = [
pts_scale_ratio=1, pts_scale_ratio=1,
flip=False, flip=False,
transforms=[ transforms=[
dict(type='Resize', multiscale_mode='value', keep_ratio=True), # Temporary solution, fix this after refactoring the aug test
dict(type='Resize', scale=0, keep_ratio=True),
dict( dict(
type='GlobalRotScaleTrans', type='GlobalRotScaleTrans',
rot_range=[0, 0], rot_range=[0, 0],
scale_ratio_range=[1., 1.], scale_ratio_range=[1., 1.],
translation_std=[0, 0, 0]), translation_std=[0, 0, 0]),
dict(type='RandomFlip3D'), dict(type='RandomFlip3D'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict( dict(
type='PointsRangeFilter', point_cloud_range=point_cloud_range), type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict( ]),
type='DefaultFormatBundle3D', dict(type='Pack3DDetInputs', keys=['points', 'img'])
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['points', 'img'])
])
] ]
# construct a pipeline for data and gt loading in show function modality = dict(use_lidar=True, use_camera=True)
# please keep its loading function consistent with test_pipeline (e.g. client) train_dataloader = dict(
eval_pipeline = [ batch_size=2,
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4), num_workers=2,
dict(type='LoadImageFromFile'), sampler=dict(type='DefaultSampler', shuffle=True),
dict( dataset=dict(
type='DefaultFormatBundle3D',
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['points', 'img'])
]
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
train=dict(
type='RepeatDataset', type='RepeatDataset',
times=2, times=2,
dataset=dict( dataset=dict(
type=dataset_type, type=dataset_type,
data_root=data_root, data_root=data_root,
ann_file=data_root + 'kitti_infos_train.pkl', modality=modality,
split='training', ann_file='kitti_infos_train.pkl',
pts_prefix='velodyne_reduced', data_prefix=dict(
pts='training/velodyne_reduced', img='training/image_2'),
pipeline=train_pipeline, pipeline=train_pipeline,
modality=input_modality, filter_empty_gt=False,
classes=class_names, metainfo=metainfo,
test_mode=False, # we use box_type_3d='LiDAR' in kitti and nuscenes dataset
box_type_3d='LiDAR')), # and box_type_3d='Depth' in sunrgbd and scannet dataset.
val=dict( box_type_3d='LiDAR')))
val_dataloader = dict(
batch_size=1,
num_workers=1,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type, type=dataset_type,
data_root=data_root, data_root=data_root,
ann_file=data_root + 'kitti_infos_val.pkl', modality=modality,
split='training', ann_file='kitti_infos_val.pkl',
pts_prefix='velodyne_reduced', data_prefix=dict(
pts='training/velodyne_reduced', img='training/image_2'),
pipeline=test_pipeline, pipeline=test_pipeline,
modality=input_modality, metainfo=metainfo,
classes=class_names,
test_mode=True, test_mode=True,
box_type_3d='LiDAR'), box_type_3d='LiDAR'))
test=dict( test_dataloader = dict(
batch_size=1,
num_workers=1,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type, type=dataset_type,
data_root=data_root, data_root=data_root,
ann_file=data_root + 'kitti_infos_val.pkl', ann_file='kitti_infos_val.pkl',
split='training', modality=modality,
pts_prefix='velodyne_reduced', data_prefix=dict(
pts='training/velodyne_reduced', img='training/image_2'),
pipeline=test_pipeline, pipeline=test_pipeline,
modality=input_modality, metainfo=metainfo,
classes=class_names,
test_mode=True, test_mode=True,
box_type_3d='LiDAR')) box_type_3d='LiDAR'))
# Training settings optim_wrapper = dict(
optimizer = dict(weight_decay=0.01) optimizer=dict(weight_decay=0.01),
# max_norm=10 is better for SECOND clip_grad=dict(max_norm=35, norm_type=2),
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) )
val_evaluator = dict(
evaluation = dict(interval=1, pipeline=eval_pipeline) type='KittiMetric', ann_file='data/kitti/kitti_infos_val.pkl')
test_evaluator = val_evaluator
# You may need to download the model first if the network is unstable # You may need to download the model first if the network is unstable
load_from = 'https://download.openmmlab.com/mmdetection3d/pretrain_models/mvx_faster_rcnn_detectron2-caffe_20e_coco-pretrain_gt-sample_kitti-3-class_moderate-79.3_20200207-a4a6a3c7.pth' # noqa load_from = 'https://download.openmmlab.com/mmdetection3d/pretrain_models/mvx_faster_rcnn_detectron2-caffe_20e_coco-pretrain_gt-sample_kitti-3-class_moderate-79.3_20200207-a4a6a3c7.pth' # noqa
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .assigners import AssignResult, BaseAssigner, MaxIoUAssigner from .assigners import AssignResult, BaseAssigner, Max3DIoUAssigner
# from .bbox_target import bbox_target # from .bbox_target import bbox_target
from .builder import build_assigner, build_bbox_coder, build_sampler from .builder import build_assigner, build_bbox_coder, build_sampler
from .coders import DeltaXYZWLHRBBoxCoder from .coders import DeltaXYZWLHRBBoxCoder
...@@ -18,7 +18,7 @@ from .structures import (BaseInstance3DBoxes, Box3DMode, CameraInstance3DBoxes, ...@@ -18,7 +18,7 @@ from .structures import (BaseInstance3DBoxes, Box3DMode, CameraInstance3DBoxes,
from .transforms import bbox3d2result, bbox3d2roi, bbox3d_mapping_back from .transforms import bbox3d2result, bbox3d2roi, bbox3d_mapping_back
__all__ = [ __all__ = [
'BaseSampler', 'AssignResult', 'BaseAssigner', 'MaxIoUAssigner', 'BaseSampler', 'AssignResult', 'BaseAssigner', 'Max3DIoUAssigner',
'PseudoSampler', 'RandomSampler', 'InstanceBalancedPosSampler', 'PseudoSampler', 'RandomSampler', 'InstanceBalancedPosSampler',
'IoUBalancedNegSampler', 'CombinedSampler', 'SamplingResult', 'IoUBalancedNegSampler', 'CombinedSampler', 'SamplingResult',
'DeltaXYZWLHRBBoxCoder', 'BboxOverlapsNearest3D', 'BboxOverlaps3D', 'DeltaXYZWLHRBBoxCoder', 'BboxOverlapsNearest3D', 'BboxOverlaps3D',
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmdet.core.bbox import AssignResult, BaseAssigner from mmdet.core.bbox import AssignResult, BaseAssigner
from .max_3d_iou_assigner import MaxIoUAssigner from .max_3d_iou_assigner import Max3DIoUAssigner
__all__ = ['BaseAssigner', 'MaxIoUAssigner', 'AssignResult'] __all__ = ['BaseAssigner', 'Max3DIoUAssigner', 'AssignResult']
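With the rename, `Max3DIoUAssigner` is built the same way as before through the bbox builder re-exported above. A hedged sketch using the 'Pedestrian' thresholds from the MVX `train_cfg` hunk earlier (illustration only; the import path is assumed from this `__init__`):

from mmdet3d.core.bbox import build_assigner  # path assumed from the __init__ shown above

pedestrian_assigner = build_assigner(
    dict(
        type='Max3DIoUAssigner',
        iou_calculator=dict(type='BboxOverlapsNearest3D'),
        pos_iou_thr=0.35,
        neg_iou_thr=0.2,
        min_pos_iou=0.2,
        ignore_iof_thr=-1))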
...@@ -35,6 +35,8 @@ class Det3DDataset(BaseDataset): ...@@ -35,6 +35,8 @@ class Det3DDataset(BaseDataset):
- use_camera: bool - use_camera: bool
- use_lidar: bool - use_lidar: bool
Defaults to `dict(use_lidar=True, use_camera=False)` Defaults to `dict(use_lidar=True, use_camera=False)`
default_cam_key (str, optional): The default camera name adopted.
Defaults to None.
box_type_3d (str, optional): Type of 3D box of this dataset. box_type_3d (str, optional): Type of 3D box of this dataset.
Based on the `box_type_3d`, the dataset will encapsulate the box Based on the `box_type_3d`, the dataset will encapsulate the box
to its original format then converted them to `box_type_3d`. to its original format then converted them to `box_type_3d`.
...@@ -65,6 +67,7 @@ class Det3DDataset(BaseDataset): ...@@ -65,6 +67,7 @@ class Det3DDataset(BaseDataset):
data_prefix: dict = dict(pts='velodyne', img=''), data_prefix: dict = dict(pts='velodyne', img=''),
pipeline: List[Union[dict, Callable]] = [], pipeline: List[Union[dict, Callable]] = [],
modality: dict = dict(use_lidar=True, use_camera=False), modality: dict = dict(use_lidar=True, use_camera=False),
default_cam_key: str = None,
box_type_3d: dict = 'LiDAR', box_type_3d: dict = 'LiDAR',
filter_empty_gt: bool = True, filter_empty_gt: bool = True,
test_mode: bool = False, test_mode: bool = False,
...@@ -84,6 +87,7 @@ class Det3DDataset(BaseDataset): ...@@ -84,6 +87,7 @@ class Det3DDataset(BaseDataset):
if key not in modality: if key not in modality:
modality[key] = False modality[key] = False
self.modality = modality self.modality = modality
self.default_cam_key = default_cam_key
assert self.modality['use_lidar'] or self.modality['use_camera'], ( assert self.modality['use_lidar'] or self.modality['use_camera'], (
'Please specify the `modality` (`use_lidar` ' 'Please specify the `modality` (`use_lidar` '
f', `use_camera`) for {self.__class__.__name__}') f', `use_camera`) for {self.__class__.__name__}')
...@@ -233,6 +237,20 @@ class Det3DDataset(BaseDataset): ...@@ -233,6 +237,20 @@ class Det3DDataset(BaseDataset):
cam_prefix = self.data_prefix.get('img', '') cam_prefix = self.data_prefix.get('img', '')
img_info['img_path'] = osp.join(cam_prefix, img_info['img_path'] = osp.join(cam_prefix,
img_info['img_path']) img_info['img_path'])
if self.default_cam_key is not None:
info['img_path'] = info['images'][
self.default_cam_key]['img_path']
if 'lidar2cam' in info['images'][self.default_cam_key]:
info['lidar2cam'] = np.array(
info['images'][self.default_cam_key]['lidar2cam'])
if 'cam2img' in info['images'][self.default_cam_key]:
info['cam2img'] = np.array(
info['images'][self.default_cam_key]['cam2img'])
if 'lidar2img' in info['images'][self.default_cam_key]:
info['lidar2img'] = np.array(
info['images'][self.default_cam_key]['lidar2img'])
else:
info['lidar2img'] = info['cam2img'] @ info['lidar2cam']
if not self.test_mode: if not self.test_mode:
# used in training # used in training
......
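When the info file lacks a precomputed 'lidar2img', the fallback above composes it from the camera matrices of `default_cam_key`. A small numpy sketch of that composition with dummy matrices; the 4x4 shapes are an assumption about how the info files store them:

import numpy as np

cam2img = np.eye(4)              # camera intrinsics padded to 4x4 (dummy)
lidar2cam = np.eye(4)            # lidar-to-camera extrinsics (dummy)
lidar2img = cam2img @ lidar2cam  # same fallback as in the dataset code above

point = np.array([10.0, 2.0, 1.5, 1.0])   # homogeneous lidar point
u, v, depth = (lidar2img @ point)[:3]
u, v = u / depth, v / depth               # pixel coordinates after the perspective divide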
...@@ -49,6 +49,7 @@ class KittiDataset(Det3DDataset): ...@@ -49,6 +49,7 @@ class KittiDataset(Det3DDataset):
ann_file: str, ann_file: str,
pipeline: List[Union[dict, Callable]] = [], pipeline: List[Union[dict, Callable]] = [],
modality: Optional[dict] = dict(use_lidar=True), modality: Optional[dict] = dict(use_lidar=True),
default_cam_key='CAM2',
box_type_3d: str = 'LiDAR', box_type_3d: str = 'LiDAR',
filter_empty_gt: bool = True, filter_empty_gt: bool = True,
test_mode: bool = False, test_mode: bool = False,
...@@ -61,6 +62,7 @@ class KittiDataset(Det3DDataset): ...@@ -61,6 +62,7 @@ class KittiDataset(Det3DDataset):
ann_file=ann_file, ann_file=ann_file,
pipeline=pipeline, pipeline=pipeline,
modality=modality, modality=modality,
default_cam_key=default_cam_key,
box_type_3d=box_type_3d, box_type_3d=box_type_3d,
filter_empty_gt=filter_empty_gt, filter_empty_gt=filter_empty_gt,
test_mode=test_mode, test_mode=test_mode,
......
...@@ -111,9 +111,9 @@ class Base3DDenseHead(BaseModule, metaclass=ABCMeta): ...@@ -111,9 +111,9 @@ class Base3DDenseHead(BaseModule, metaclass=ABCMeta):
Args: Args:
x (tuple[Tensor]): Features from FPN. x (tuple[Tensor]): Features from FPN.
batch_data_samples (list[:obj:`DetDataSample`]): Each item contains batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
the meta information of each image and corresponding contains the meta information of each image and
annotations. corresponding annotations.
proposal_cfg (ConfigDict, optional): Test / postprocessing proposal_cfg (ConfigDict, optional): Test / postprocessing
configuration, if None, test_cfg would be used. configuration, if None, test_cfg would be used.
Defaults to None. Defaults to None.
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import copy import copy
from typing import Dict, List, Optional, Tuple, Union
import torch import torch
from mmcv.cnn import ConvModule, build_conv_layer from mmcv.cnn import ConvModule, build_conv_layer
from mmcv.runner import BaseModule, force_fp32 from mmcv.runner import BaseModule, force_fp32
from torch import nn from mmengine import InstanceData
from torch import Tensor, nn
from mmdet3d.core import (circle_nms, draw_heatmap_gaussian, gaussian_radius, from mmdet3d.core import (Det3DDataSample, circle_nms, draw_heatmap_gaussian,
xywhr2xyxyr) gaussian_radius, xywhr2xyxyr)
from mmdet3d.core.post_processing import nms_bev from mmdet3d.core.post_processing import nms_bev
from mmdet3d.models import builder from mmdet3d.models import builder
from mmdet3d.models.builder import build_loss
from mmdet3d.models.utils import clip_sigmoid from mmdet3d.models.utils import clip_sigmoid
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet.core import build_bbox_coder, multi_apply from mmdet.core import multi_apply
@MODELS.register_module() @MODELS.register_module()
...@@ -53,7 +54,6 @@ class SeparateHead(BaseModule): ...@@ -53,7 +54,6 @@ class SeparateHead(BaseModule):
self.init_bias = init_bias self.init_bias = init_bias
for head in self.heads: for head in self.heads:
classes, num_conv = self.heads[head] classes, num_conv = self.heads[head]
conv_layers = [] conv_layers = []
c_in = in_channels c_in = in_channels
for i in range(num_conv - 1): for i in range(num_conv - 1):
...@@ -250,8 +250,6 @@ class CenterHead(BaseModule): ...@@ -250,8 +250,6 @@ class CenterHead(BaseModule):
feature map. Default: [128]. feature map. Default: [128].
tasks (list[dict], optional): Task information including class number tasks (list[dict], optional): Task information including class number
and class names. Default: None. and class names. Default: None.
train_cfg (dict, optional): Train-time configs. Default: None.
test_cfg (dict, optional): Test-time configs. Default: None.
bbox_coder (dict, optional): Bbox coder configs. Default: None. bbox_coder (dict, optional): Bbox coder configs. Default: None.
common_heads (dict, optional): Conv information for common heads. common_heads (dict, optional): Conv information for common heads.
Default: dict(). Default: dict().
...@@ -269,32 +267,45 @@ class CenterHead(BaseModule): ...@@ -269,32 +267,45 @@ class CenterHead(BaseModule):
Default: dict(type='Conv2d') Default: dict(type='Conv2d')
norm_cfg (dict, optional): Config of norm layer. norm_cfg (dict, optional): Config of norm layer.
Default: dict(type='BN2d'). Default: dict(type='BN2d').
bias (str, optional): Type of bias. Default: 'auto'. bias (str): Type of bias. Default: 'auto'.
norm_bbox (bool): Whether to normalize the bbox predictions.
Defaults to True.
train_cfg (dict, optional): Train-time configs. Default: None.
test_cfg (dict, optional): Test-time configs. Default: None.
init_cfg (dict, optional): Config for initialization.
""" """
def __init__(self, def __init__(self,
in_channels=[128], in_channels: Union[List[int], int] = [128],
tasks=None, tasks: Optional[List[dict]] = None,
train_cfg=None, bbox_coder: Optional[dict] = None,
test_cfg=None, common_heads: dict = dict(),
bbox_coder=None, loss_cls: dict = dict(
common_heads=dict(), type='mmdet.GaussianFocalLoss', reduction='mean'),
loss_cls=dict(type='GaussianFocalLoss', reduction='mean'), loss_bbox: dict = dict(
loss_bbox=dict( type='mmdet.L1Loss', reduction='none', loss_weight=0.25),
type='L1Loss', reduction='none', loss_weight=0.25), separate_head: dict = dict(
separate_head=dict( type='mmdet.SeparateHead',
type='SeparateHead', init_bias=-2.19, final_kernel=3), init_bias=-2.19,
share_conv_channel=64, final_kernel=3),
num_heatmap_convs=2, share_conv_channel: int = 64,
conv_cfg=dict(type='Conv2d'), num_heatmap_convs: int = 2,
norm_cfg=dict(type='BN2d'), conv_cfg: dict = dict(type='Conv2d'),
bias='auto', norm_cfg: dict = dict(type='BN2d'),
norm_bbox=True, bias: str = 'auto',
init_cfg=None): norm_bbox: bool = True,
train_cfg: Optional[dict] = None,
test_cfg: Optional[dict] = None,
init_cfg: Optional[dict] = None,
**kwargs):
assert init_cfg is None, 'To prevent abnormal initialization ' \ assert init_cfg is None, 'To prevent abnormal initialization ' \
'behavior, init_cfg is not allowed to be set' 'behavior, init_cfg is not allowed to be set'
super(CenterHead, self).__init__(init_cfg=init_cfg) super(CenterHead, self).__init__(init_cfg=init_cfg, **kwargs)
# TODO we should rename this variable,
# for example num_classes_per_task ?
# {'num_class': 2, 'class_names': ['pedestrian', 'traffic_cone']}]
# TODO seems num_classes is useless
num_classes = [len(t['class_names']) for t in tasks] num_classes = [len(t['class_names']) for t in tasks]
self.class_names = [t['class_names'] for t in tasks] self.class_names = [t['class_names'] for t in tasks]
self.train_cfg = train_cfg self.train_cfg = train_cfg
...@@ -303,9 +314,9 @@ class CenterHead(BaseModule): ...@@ -303,9 +314,9 @@ class CenterHead(BaseModule):
self.num_classes = num_classes self.num_classes = num_classes
self.norm_bbox = norm_bbox self.norm_bbox = norm_bbox
self.loss_cls = build_loss(loss_cls) self.loss_cls = MODELS.build(loss_cls)
self.loss_bbox = build_loss(loss_bbox) self.loss_bbox = MODELS.build(loss_bbox)
self.bbox_coder = build_bbox_coder(bbox_coder) self.bbox_coder = TASK_UTILS.build(bbox_coder)
self.num_anchor_per_locs = [n for n in num_classes] self.num_anchor_per_locs = [n for n in num_classes]
self.fp16_enabled = False self.fp16_enabled = False
...@@ -328,7 +339,7 @@ class CenterHead(BaseModule): ...@@ -328,7 +339,7 @@ class CenterHead(BaseModule):
in_channels=share_conv_channel, heads=heads, num_cls=num_cls) in_channels=share_conv_channel, heads=heads, num_cls=num_cls)
self.task_heads.append(builder.build_head(separate_head)) self.task_heads.append(builder.build_head(separate_head))
def forward_single(self, x): def forward_single(self, x: Tensor) -> dict:
"""Forward function for CenterPoint. """Forward function for CenterPoint.
Args: Args:
...@@ -347,7 +358,7 @@ class CenterHead(BaseModule): ...@@ -347,7 +358,7 @@ class CenterHead(BaseModule):
return ret_dicts return ret_dicts
def forward(self, feats): def forward(self, feats: List[Tensor]) -> Tuple[List[Tensor]]:
"""Forward pass. """Forward pass.
Args: Args:
...@@ -384,7 +395,10 @@ class CenterHead(BaseModule): ...@@ -384,7 +395,10 @@ class CenterHead(BaseModule):
feat = feat.view(-1, dim) feat = feat.view(-1, dim)
return feat return feat
def get_targets(self, gt_bboxes_3d, gt_labels_3d): def get_targets(
self,
batch_gt_instances_3d: List[InstanceData],
) -> Tuple[List[Tensor]]:
"""Generate targets. """Generate targets.
How each output is transformed: How each output is transformed:
...@@ -399,24 +413,24 @@ class CenterHead(BaseModule): ...@@ -399,24 +413,24 @@ class CenterHead(BaseModule):
[ tensor0, tensor1, tensor2, ... ] [ tensor0, tensor1, tensor2, ... ]
Args: Args:
gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
truth gt boxes. gt_instances. It usually includes ``bboxes_3d`` and
gt_labels_3d (list[torch.Tensor]): Labels of boxes. ``labels_3d`` attributes.
Returns: Returns:
tuple[list[torch.Tensor]]: Tuple of target including tuple[list[torch.Tensor]]: Tuple of target including
the following results in order. the following results in order.
- list[torch.Tensor]: Heatmap scores. - list[torch.Tensor]: Heatmap scores.
- list[torch.Tensor]: Ground truth boxes. - list[torch.Tensor]: Ground truth boxes.
- list[torch.Tensor]: Indexes indicating the - list[torch.Tensor]: Indexes indicating the
position of the valid boxes. position of the valid boxes.
- list[torch.Tensor]: Masks indicating which - list[torch.Tensor]: Masks indicating which
boxes are valid. boxes are valid.
""" """
heatmaps, anno_boxes, inds, masks = multi_apply( heatmaps, anno_boxes, inds, masks = multi_apply(
self.get_targets_single, gt_bboxes_3d, gt_labels_3d) self.get_targets_single, batch_gt_instances_3d)
# Transpose heatmaps # Transpose heatmaps
heatmaps = list(map(list, zip(*heatmaps))) heatmaps = list(map(list, zip(*heatmaps)))
heatmaps = [torch.stack(hms_) for hms_ in heatmaps] heatmaps = [torch.stack(hms_) for hms_ in heatmaps]
...@@ -431,12 +445,14 @@ class CenterHead(BaseModule): ...@@ -431,12 +445,14 @@ class CenterHead(BaseModule):
masks = [torch.stack(masks_) for masks_ in masks] masks = [torch.stack(masks_) for masks_ in masks]
return heatmaps, anno_boxes, inds, masks return heatmaps, anno_boxes, inds, masks
def get_targets_single(self, gt_bboxes_3d, gt_labels_3d): def get_targets_single(self,
gt_instances_3d: InstanceData) -> Tuple[Tensor]:
"""Generate training targets for a single sample. """Generate training targets for a single sample.
Args: Args:
gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`): Ground truth gt boxes. gt_instances_3d (:obj:`InstanceData`): Gt_instances of
gt_labels_3d (torch.Tensor): Labels of boxes. single data sample. It usually includes
``bboxes_3d`` and ``labels_3d`` attributes.
Returns: Returns:
tuple[list[torch.Tensor]]: Tuple of target including tuple[list[torch.Tensor]]: Tuple of target including
...@@ -449,6 +465,8 @@ class CenterHead(BaseModule): ...@@ -449,6 +465,8 @@ class CenterHead(BaseModule):
- list[torch.Tensor]: Masks indicating which boxes - list[torch.Tensor]: Masks indicating which boxes
are valid. are valid.
""" """
gt_labels_3d = gt_instances_3d.labels_3d
gt_bboxes_3d = gt_instances_3d.bboxes_3d
device = gt_labels_3d.device device = gt_labels_3d.device
gt_bboxes_3d = torch.cat( gt_bboxes_3d = torch.cat(
(gt_bboxes_3d.gravity_center, gt_bboxes_3d.tensor[:, 3:]), (gt_bboxes_3d.gravity_center, gt_bboxes_3d.tensor[:, 3:]),
...@@ -569,21 +587,48 @@ class CenterHead(BaseModule): ...@@ -569,21 +587,48 @@ class CenterHead(BaseModule):
inds.append(ind) inds.append(ind)
return heatmaps, anno_boxes, inds, masks return heatmaps, anno_boxes, inds, masks
def loss(self, pts_feats: List[Tensor],
batch_data_samples: List[Det3DDataSample], *args,
**kwargs) -> Dict[str, Tensor]:
"""Forward function for point cloud branch.
Args:
pts_feats (list[torch.Tensor]): Features of point cloud branch
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes information such as
`gt_instances_3d`.
Returns:
dict: Losses of each branch.
"""
outs = self(pts_feats)
batch_gt_instance_3d = []
for data_sample in batch_data_samples:
batch_gt_instance_3d.append(data_sample.gt_instances_3d)
losses = self.loss_by_feat(outs, batch_gt_instance_3d)
return losses
@force_fp32(apply_to=('preds_dicts')) @force_fp32(apply_to=('preds_dicts'))
def loss(self, gt_bboxes_3d, gt_labels_3d, preds_dicts, **kwargs): def loss_by_feat(self, preds_dicts: Tuple[List[dict]],
batch_gt_instances_3d: List[InstanceData], *args,
**kwargs):
"""Loss function for CenterHead. """Loss function for CenterHead.
Args: Args:
gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground preds_dicts (tuple[list[dict]]): Prediction results of
truth gt boxes. multiple tasks. The outer tuple indicates different
gt_labels_3d (list[torch.Tensor]): Labels of boxes. task heads, and the inner list indicates different
preds_dicts (dict): Output of forward function. FPN levels.
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instances. It usually includes ``bboxes_3d`` and\
``labels_3d`` attributes.
Returns: Returns:
dict[str:torch.Tensor]: Loss of heatmap and bbox of each task. dict[str,torch.Tensor]: Loss of heatmap and bbox of each task.
""" """
heatmaps, anno_boxes, inds, masks = self.get_targets( heatmaps, anno_boxes, inds, masks = self.get_targets(
gt_bboxes_3d, gt_labels_3d) batch_gt_instances_3d)
loss_dict = dict() loss_dict = dict()
for task_id, preds_dict in enumerate(preds_dicts): for task_id, preds_dict in enumerate(preds_dicts):
# heatmap focal loss # heatmap focal loss
...@@ -619,15 +664,62 @@ class CenterHead(BaseModule): ...@@ -619,15 +664,62 @@ class CenterHead(BaseModule):
loss_dict[f'task{task_id}.loss_bbox'] = loss_bbox loss_dict[f'task{task_id}.loss_bbox'] = loss_bbox
return loss_dict return loss_dict
def get_bboxes(self, preds_dicts, img_metas, img=None, rescale=False): def predict(self,
pts_feats: Dict[str, torch.Tensor],
batch_data_samples: List[Det3DDataSample],
rescale=True,
**kwargs) -> List[InstanceData]:
"""
Args:
pts_feats (dict): Point features.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes meta information of data.
rescale (bool): Whether to rescale the results to
the original scale.
Returns:
list[:obj:`InstanceData`]: List of processed predictions. Each
InstanceData contains 3D bounding boxes and corresponding
scores and labels.
"""
preds_dict = self(pts_feats)
batch_size = len(batch_data_samples)
batch_input_metas = []
for batch_index in range(batch_size):
metainfo = batch_data_samples[batch_index].metainfo
batch_input_metas.append(metainfo)
results_list = self.predict_by_feat(
preds_dict, batch_input_metas, rescale=rescale, **kwargs)
return results_list
def predict_by_feat(self, preds_dicts: Tuple[List[dict]],
batch_input_metas: List[dict], *args,
**kwargs) -> List[InstanceData]:
"""Generate bboxes from bbox head predictions. """Generate bboxes from bbox head predictions.
Args: Args:
preds_dicts (tuple[list[dict]]): Prediction results. preds_dicts (tuple[list[dict]]): Prediction results of
img_metas (list[dict]): Point cloud and image's meta info. multiple tasks. The outer tuple indicates different
task heads, and the inner list indicates different
FPN levels.
batch_input_metas (list[dict]): Meta info of multiple
inputs.
Returns: Returns:
list[dict]: Decoded bbox, scores and labels after nms. list[:obj:`InstanceData`]: Instance prediction
results of each sample after the post process.
Each item usually contains following keys.
- scores_3d (Tensor): Classification scores, has a shape
(num_instance, )
- labels_3d (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes_3d (:obj:`LiDARInstance3DBoxes`): Prediction
of bboxes, contains a tensor with shape
(num_instances, 7) or (num_instances, 9), where
the last two dimensions of the 9-dim boxes
are velocity.
""" """
rets = [] rets = []
for task_id, preds_dict in enumerate(preds_dicts): for task_id, preds_dict in enumerate(preds_dicts):
...@@ -689,18 +781,20 @@ class CenterHead(BaseModule): ...@@ -689,18 +781,20 @@ class CenterHead(BaseModule):
rets.append( rets.append(
self.get_task_detections(num_class_with_bg, self.get_task_detections(num_class_with_bg,
batch_cls_preds, batch_reg_preds, batch_cls_preds, batch_reg_preds,
batch_cls_labels, img_metas)) batch_cls_labels,
batch_input_metas))
# Merge branches results # Merge branches results
num_samples = len(rets[0]) num_samples = len(rets[0])
ret_list = [] ret_list = []
for i in range(num_samples): for i in range(num_samples):
temp_instances = InstanceData()
for k in rets[0][i].keys(): for k in rets[0][i].keys():
if k == 'bboxes': if k == 'bboxes':
bboxes = torch.cat([ret[i][k] for ret in rets]) bboxes = torch.cat([ret[i][k] for ret in rets])
bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5 bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5
bboxes = img_metas[i]['box_type_3d']( bboxes = batch_input_metas[i]['box_type_3d'](
bboxes, self.bbox_coder.code_size) bboxes, self.bbox_coder.code_size)
elif k == 'scores': elif k == 'scores':
scores = torch.cat([ret[i][k] for ret in rets]) scores = torch.cat([ret[i][k] for ret in rets])
...@@ -710,7 +804,10 @@ class CenterHead(BaseModule): ...@@ -710,7 +804,10 @@ class CenterHead(BaseModule):
rets[j][i][k] += flag rets[j][i][k] += flag
flag += num_class flag += num_class
labels = torch.cat([ret[i][k].int() for ret in rets]) labels = torch.cat([ret[i][k].int() for ret in rets])
ret_list.append([bboxes, scores, labels]) temp_instances.bboxes_3d = bboxes
temp_instances.scores_3d = scores
temp_instances.labels_3d = labels
ret_list.append(temp_instances)
return ret_list return ret_list
def get_task_detections(self, num_class_with_bg, batch_cls_preds, def get_task_detections(self, num_class_with_bg, batch_cls_preds,
......
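The refactored `predict_by_feat` returns one `InstanceData` per sample instead of a `[bboxes, scores, labels]` list. A tiny sketch of that container with dummy tensors (in the real head the boxes are wrapped by the `box_type_3d` class from the input metas):

import torch
from mmengine import InstanceData  # same import used by the refactored head

pred = InstanceData()
pred.bboxes_3d = torch.zeros(2, 9)        # dummy boxes; normally LiDARInstance3DBoxes
pred.scores_3d = torch.tensor([0.9, 0.7])
pred.labels_3d = torch.tensor([0, 3])

assert len(pred) == 2                     # fields stay aligned per instance
keep = pred[pred.scores_3d > 0.8]         # indexing keeps all fields in sync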
...@@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Union ...@@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Union
import numpy as np import numpy as np
import torch import torch
from mmcv.ops import furthest_point_sample from mmcv.ops import furthest_point_sample
from mmcv.runner import BaseModule, force_fp32 from mmcv.runner import BaseModule
from mmengine import ConfigDict, InstanceData from mmengine import ConfigDict, InstanceData
from torch.nn import functional as F from torch.nn import functional as F
...@@ -308,7 +308,6 @@ class VoteHead(BaseModule): ...@@ -308,7 +308,6 @@ class VoteHead(BaseModule):
results.update(decode_res) results.update(decode_res)
return results return results
@force_fp32(apply_to=('bbox_preds', ))
def loss_by_feat( def loss_by_feat(
self, self,
points: List[torch.Tensor], points: List[torch.Tensor],
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Union from typing import List, Union
from mmengine import InstanceData
from mmdet3d.core import Det3DDataSample from mmdet3d.core import Det3DDataSample
from mmdet3d.core.utils import (ForwardResults, InstanceList, OptConfigType, from mmdet3d.core.utils import (ForwardResults, InstanceList, OptConfigType,
OptMultiConfig, OptSampleList, SampleList) OptMultiConfig, OptSampleList, SampleList)
...@@ -38,7 +40,7 @@ class Base3DDetector(BaseDetector): ...@@ -38,7 +40,7 @@ class Base3DDetector(BaseDetector):
- "tensor": Forward the whole network and return tensor or tuple of - "tensor": Forward the whole network and return tensor or tuple of
tensor without any post-processing, same as a common nn.Module. tensor without any post-processing, same as a common nn.Module.
- "predict": Forward and return the predictions, which are fully - "predict": Forward and return the predictions, which are fully
processed to a list of :obj:`DetDataSample`. processed to a list of :obj:`Det3DDataSample`.
- "loss": Forward and return a dict of losses according to the given - "loss": Forward and return a dict of losses according to the given
inputs and data samples. inputs and data samples.
...@@ -53,8 +55,8 @@ class Base3DDetector(BaseDetector): ...@@ -53,8 +55,8 @@ class Base3DDetector(BaseDetector):
- points (list[torch.Tensor]): Point cloud of each sample. - points (list[torch.Tensor]): Point cloud of each sample.
- imgs (torch.Tensor): Image tensor has shape (B, C, H, W). - imgs (torch.Tensor): Image tensor has shape (B, C, H, W).
data_samples (list[:obj:`DetDataSample`], data_samples (list[:obj:`Det3DDataSample`],
list[list[:obj:`DetDataSample`]], optional): The list[list[:obj:`Det3DDataSample`]], optional): The
annotation data of every sample. When it is a list[list], the annotation data of every sample. When it is a list[list], the
outer list indicates the test time augmentation, and the outer list indicates the test time augmentation, and the
inner list indicates the batch. Otherwise, the list simply inner list indicates the batch. Otherwise, the list simply
...@@ -65,7 +67,7 @@ class Base3DDetector(BaseDetector): ...@@ -65,7 +67,7 @@ class Base3DDetector(BaseDetector):
The return type depends on ``mode``. The return type depends on ``mode``.
- If ``mode="tensor"``, return a tensor or a tuple of tensor. - If ``mode="tensor"``, return a tensor or a tuple of tensor.
- If ``mode="predict"``, return a list of :obj:`DetDataSample`. - If ``mode="predict"``, return a list of :obj:`Det3DDataSample`.
- If ``mode="loss"``, return a dict of tensor. - If ``mode="loss"``, return a dict of tensor.
""" """
if mode == 'loss': if mode == 'loss':
...@@ -87,7 +89,11 @@ class Base3DDetector(BaseDetector): ...@@ -87,7 +89,11 @@ class Base3DDetector(BaseDetector):
raise RuntimeError(f'Invalid mode "{mode}". ' raise RuntimeError(f'Invalid mode "{mode}". '
'Only supports loss, predict and tensor mode') 'Only supports loss, predict and tensor mode')
def convert_to_datasample(self, results_list: InstanceList) -> SampleList: def convert_to_datasample(
self,
results_list_3d: InstanceList,
results_list_2d: InstanceList = None,
) -> SampleList:
"""Convert results list to `Det3DDataSample`. """Convert results list to `Det3DDataSample`.
Subclasses could override it to be compatible for some multi-modality Subclasses could override it to be compatible for some multi-modality
...@@ -100,19 +106,35 @@ class Base3DDetector(BaseDetector): ...@@ -100,19 +106,35 @@ class Base3DDetector(BaseDetector):
Returns: Returns:
list[:obj:`Det3DDataSample`]: Detection results of the list[:obj:`Det3DDataSample`]: Detection results of the
input. Each Det3DDataSample usually contains input. Each Det3DDataSample usually contains
'pred_instances_3d'. And the ``pred_instances_3d`` usually 'pred_instances_3d'. And the ``pred_instances_3d`` normally
contains following keys.
- scores_3d (Tensor): Classification scores, has a shape
(num_instance, )
- labels_3d (Tensor): Labels of 3D bboxes, has a shape
(num_instances, ).
- bboxes_3d (Tensor): Contains a tensor with shape
(num_instances, C) where C >=7.
When there are image predictions in some models, it should
contain ``pred_instances``, and the ``pred_instances`` normally
contains following keys. contains following keys.
- scores_3d (Tensor): Classification scores, has a shape - scores (Tensor): Classification scores of image, has a shape
(num_instance, ) (num_instance, )
- labels_3d (Tensor): Labels of 3D bboxes, has a shape - labels (Tensor): Predicted labels of 2D bboxes, has a shape
(num_instances, ). (num_instances, ).
- bboxes_3d (Tensor): Contains a tensor with shape - bboxes (Tensor): Contains a tensor with shape
(num_instances, C) where C >=7. (num_instances, 4).
""" """
out_results_list = []
for i in range(len(results_list)): data_sample_list = []
if results_list_2d is None:
results_list_2d = [
InstanceData() for _ in range(len(results_list_3d))
]
for i in range(len(results_list_3d)):
result = Det3DDataSample() result = Det3DDataSample()
result.pred_instances_3d = results_list[i] result.pred_instances_3d = results_list_3d[i]
out_results_list.append(result) result.pred_instances = results_list_2d[i]
return out_results_list data_sample_list.append(result)
return data_sample_list
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import torch import torch
from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d from mmdet3d.core import merge_aug_bboxes_3d
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from .mvx_two_stage import MVXTwoStageDetector from .mvx_two_stage import MVXTwoStageDetector
@MODELS.register_module() @MODELS.register_module()
class CenterPoint(MVXTwoStageDetector): class CenterPoint(MVXTwoStageDetector):
"""Base class of Multi-modality VoxelNet.""" """Base class of Multi-modality VoxelNet.
Args:
pts_voxel_layer (dict, optional): Point cloud voxelization
layer. Defaults to None.
pts_voxel_encoder (dict, optional): Point voxelization
encoder layer. Defaults to None.
pts_middle_encoder (dict, optional): Middle encoder layer
of point cloud modality. Defaults to None.
pts_fusion_layer (dict, optional): Fusion layer.
Defaults to None.
img_backbone (dict, optional): Backbone for extracting
image features. Defaults to None.
pts_backbone (dict, optional): Backbone for extracting
point features. Defaults to None.
img_neck (dict, optional): Neck for extracting
image features. Defaults to None.
pts_neck (dict, optional): Neck for extracting
point features. Defaults to None.
pts_bbox_head (dict, optional): Bbox head of the
point cloud modality. Defaults to None.
img_roi_head (dict, optional): RoI head of image
modality. Defaults to None.
img_rpn_head (dict, optional): RPN head of image
modality. Defaults to None.
train_cfg (dict, optional): Train config of model.
Defaults to None.
test_cfg (dict, optional): Test config of model.
Defaults to None.
init_cfg (dict, optional): Initialization config of the
model. Defaults to None.
data_preprocessor (dict or ConfigDict, optional): The pre-process
config of :class:`Det3DDataPreprocessor`. Defaults to None.
"""
def __init__(self, def __init__(self,
pts_voxel_layer=None, pts_voxel_layer: Optional[dict] = None,
pts_voxel_encoder=None, pts_voxel_encoder: Optional[dict] = None,
pts_middle_encoder=None, pts_middle_encoder: Optional[dict] = None,
pts_fusion_layer=None, pts_fusion_layer: Optional[dict] = None,
img_backbone=None, img_backbone: Optional[dict] = None,
pts_backbone=None, pts_backbone: Optional[dict] = None,
img_neck=None, img_neck: Optional[dict] = None,
pts_neck=None, pts_neck: Optional[dict] = None,
pts_bbox_head=None, pts_bbox_head: Optional[dict] = None,
img_roi_head=None, img_roi_head: Optional[dict] = None,
img_rpn_head=None, img_rpn_head: Optional[dict] = None,
train_cfg=None, train_cfg: Optional[dict] = None,
test_cfg=None, test_cfg: Optional[dict] = None,
pretrained=None, init_cfg: Optional[dict] = None,
init_cfg=None): data_preprocessor: Optional[dict] = None,
**kwargs):
super(CenterPoint, super(CenterPoint,
self).__init__(pts_voxel_layer, pts_voxel_encoder, self).__init__(pts_voxel_layer, pts_voxel_encoder,
pts_middle_encoder, pts_fusion_layer, pts_middle_encoder, pts_fusion_layer,
img_backbone, pts_backbone, img_neck, pts_neck, img_backbone, pts_backbone, img_neck, pts_neck,
pts_bbox_head, img_roi_head, img_rpn_head, pts_bbox_head, img_roi_head, img_rpn_head,
train_cfg, test_cfg, pretrained, init_cfg) train_cfg, test_cfg, init_cfg, data_preprocessor,
**kwargs)
def extract_pts_feat(self, pts, img_feats, img_metas):
"""Extract features of points."""
if not self.with_pts_bbox:
return None
voxels, num_points, coors = self.voxelize(pts)
voxel_features = self.pts_voxel_encoder(voxels, num_points, coors)
batch_size = coors[-1, 0] + 1
x = self.pts_middle_encoder(voxel_features, coors, batch_size)
x = self.pts_backbone(x)
if self.with_pts_neck:
x = self.pts_neck(x)
return x
def forward_pts_train(self,
pts_feats,
gt_bboxes_3d,
gt_labels_3d,
img_metas,
gt_bboxes_ignore=None):
"""Forward function for point cloud branch.
Args:
pts_feats (list[torch.Tensor]): Features of point cloud branch
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
boxes for each sample.
gt_labels_3d (list[torch.Tensor]): Ground truth labels for
boxes of each sample
img_metas (list[dict]): Meta information of samples.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
boxes to be ignored. Defaults to None.
Returns:
dict: Losses of each branch.
"""
outs = self.pts_bbox_head(pts_feats)
loss_inputs = [gt_bboxes_3d, gt_labels_3d, outs]
losses = self.pts_bbox_head.loss(*loss_inputs)
return losses
def simple_test_pts(self, x, img_metas, rescale=False):
"""Test function of point cloud branch."""
outs = self.pts_bbox_head(x)
bbox_list = self.pts_bbox_head.get_bboxes(
outs, img_metas, rescale=rescale)
bbox_results = [
bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list
]
return bbox_results
# TODO support this
def aug_test_pts(self, feats, img_metas, rescale=False): def aug_test_pts(self, feats, img_metas, rescale=False):
"""Test function of point cloud branch with augmentaiton. """Test function of point cloud branch with augmentaiton.
...@@ -107,6 +95,7 @@ class CenterPoint(MVXTwoStageDetector): ...@@ -107,6 +95,7 @@ class CenterPoint(MVXTwoStageDetector):
- scores_3d (torch.Tensor): Scores of predicted boxes. - scores_3d (torch.Tensor): Scores of predicted boxes.
- labels_3d (torch.Tensor): Labels of predicted boxes. - labels_3d (torch.Tensor): Labels of predicted boxes.
""" """
raise NotImplementedError
# only support aug_test for one sample # only support aug_test for one sample
outs_list = [] outs_list = []
for x, img_meta in zip(feats, img_metas): for x, img_meta in zip(feats, img_metas):
...@@ -186,7 +175,9 @@ class CenterPoint(MVXTwoStageDetector): ...@@ -186,7 +175,9 @@ class CenterPoint(MVXTwoStageDetector):
bbox_list[0][key] = bbox_list[0][key].to('cpu') bbox_list[0][key] = bbox_list[0][key].to('cpu')
return bbox_list[0] return bbox_list[0]
# TODO support this
def aug_test(self, points, img_metas, imgs=None, rescale=False): def aug_test(self, points, img_metas, imgs=None, rescale=False):
raise NotImplementedError
"""Test function with augmentaiton.""" """Test function with augmentaiton."""
img_feats, pts_feats = self.extract_feats(points, img_metas, imgs) img_feats, pts_feats = self.extract_feats(points, img_metas, imgs)
bbox_list = dict() bbox_list = dict()
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Sequence
import torch import torch
from mmcv.runner import force_fp32 from torch import Tensor
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
...@@ -23,7 +25,6 @@ class DynamicMVXFasterRCNN(MVXTwoStageDetector): ...@@ -23,7 +25,6 @@ class DynamicMVXFasterRCNN(MVXTwoStageDetector):
super(DynamicMVXFasterRCNN, self).__init__(**kwargs) super(DynamicMVXFasterRCNN, self).__init__(**kwargs)
@torch.no_grad() @torch.no_grad()
@force_fp32()
def voxelize(self, points): def voxelize(self, points):
"""Apply dynamic voxelization to points. """Apply dynamic voxelization to points.
...@@ -46,13 +47,30 @@ class DynamicMVXFasterRCNN(MVXTwoStageDetector): ...@@ -46,13 +47,30 @@ class DynamicMVXFasterRCNN(MVXTwoStageDetector):
coors_batch = torch.cat(coors_batch, dim=0) coors_batch = torch.cat(coors_batch, dim=0)
return points, coors_batch return points, coors_batch
def extract_pts_feat(self, points, img_feats, img_metas): def extract_pts_feat(
"""Extract point features.""" self,
points: List[Tensor],
img_feats: Optional[Sequence[Tensor]] = None,
batch_input_metas: Optional[List[dict]] = None
) -> Sequence[Tensor]:
"""Extract features of points.
Args:
points (List[tensor]): Point cloud of multiple inputs.
img_feats (list[Tensor], tuple[tensor], optional): Features from
image backbone.
batch_input_metas (list[dict], optional): The meta information
of multiple samples. Defaults to None.
Returns:
Sequence[Tensor]: Point features of multiple inputs
from backbone or neck.
"""
if not self.with_pts_bbox: if not self.with_pts_bbox:
return None return None
voxels, coors = self.voxelize(points) voxels, coors = self.voxelize(points)
voxel_features, feature_coors = self.pts_voxel_encoder( voxel_features, feature_coors = self.pts_voxel_encoder(
voxels, coors, points, img_feats, img_metas) voxels, coors, points, img_feats, batch_input_metas)
batch_size = coors[-1, 0] + 1 batch_size = coors[-1, 0] + 1
x = self.pts_middle_encoder(voxel_features, feature_coors, batch_size) x = self.pts_middle_encoder(voxel_features, feature_coors, batch_size)
x = self.pts_backbone(x) x = self.pts_backbone(x)
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import warnings import copy
from os import path as osp from typing import Dict, List, Optional, Sequence, Tuple
import mmcv
import torch import torch
from mmcv.ops import Voxelization from mmcv.ops import Voxelization
from mmcv.parallel import DataContainer as DC from mmengine import InstanceData
from mmcv.runner import force_fp32 from torch import Tensor
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.core import (Box3DMode, Coord3DMode, bbox3d2result, from mmdet3d.core import Det3DDataSample
merge_aug_bboxes_3d, show_result)
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from mmdet.core import multi_apply
from .base import Base3DDetector from .base import Base3DDetector
@MODELS.register_module() @MODELS.register_module()
class MVXTwoStageDetector(Base3DDetector): class MVXTwoStageDetector(Base3DDetector):
"""Base class of Multi-modality VoxelNet.""" """Base class of Multi-modality VoxelNet.
Args:
pts_voxel_layer (dict, optional): Point cloud voxelization
layer. Defaults to None.
pts_voxel_encoder (dict, optional): Point voxelization
encoder layer. Defaults to None.
pts_middle_encoder (dict, optional): Middle encoder layer
of the point cloud modality. Defaults to None.
pts_fusion_layer (dict, optional): Fusion layer.
Defaults to None.
img_backbone (dict, optional): Backbone for extracting
image features. Defaults to None.
pts_backbone (dict, optional): Backbone for extracting
point features. Defaults to None.
img_neck (dict, optional): Neck for extracting
image features. Defaults to None.
pts_neck (dict, optional): Neck for extracting
point features. Defaults to None.
pts_bbox_head (dict, optional): Bboxes head of
point cloud modality. Defaults to None.
img_roi_head (dict, optional): RoI head of image
modality. Defaults to None.
img_rpn_head (dict, optional): RPN head of image
modality. Defaults to None.
train_cfg (dict, optional): Train config of model.
Defaults to None.
test_cfg (dict, optional): Test config of model.
Defaults to None.
init_cfg (dict, optional): Initialization config of
the model. Defaults to None.
data_preprocessor (dict or ConfigDict, optional): The pre-process
config of :class:`Det3DDataPreprocessor`. Defaults to None.
"""
def __init__(self, def __init__(self,
pts_voxel_layer=None, pts_voxel_layer: Optional[dict] = None,
pts_voxel_encoder=None, pts_voxel_encoder: Optional[dict] = None,
pts_middle_encoder=None, pts_middle_encoder: Optional[dict] = None,
pts_fusion_layer=None, pts_fusion_layer: Optional[dict] = None,
img_backbone=None, img_backbone: Optional[dict] = None,
pts_backbone=None, pts_backbone: Optional[dict] = None,
img_neck=None, img_neck: Optional[dict] = None,
pts_neck=None, pts_neck: Optional[dict] = None,
pts_bbox_head=None, pts_bbox_head: Optional[dict] = None,
img_roi_head=None, img_roi_head: Optional[dict] = None,
img_rpn_head=None, img_rpn_head: Optional[dict] = None,
train_cfg=None, train_cfg: Optional[dict] = None,
test_cfg=None, test_cfg: Optional[dict] = None,
pretrained=None, init_cfg: Optional[dict] = None,
init_cfg=None): data_preprocessor: Optional[dict] = None,
super(MVXTwoStageDetector, self).__init__(init_cfg=init_cfg) **kwargs):
super(MVXTwoStageDetector, self).__init__(
init_cfg=init_cfg, data_preprocessor=data_preprocessor, **kwargs)
if pts_voxel_layer: if pts_voxel_layer:
self.pts_voxel_layer = Voxelization(**pts_voxel_layer) self.pts_voxel_layer = Voxelization(**pts_voxel_layer)
...@@ -69,35 +101,6 @@ class MVXTwoStageDetector(Base3DDetector): ...@@ -69,35 +101,6 @@ class MVXTwoStageDetector(Base3DDetector):
self.train_cfg = train_cfg self.train_cfg = train_cfg
self.test_cfg = test_cfg self.test_cfg = test_cfg
if pretrained is None:
img_pretrained = None
pts_pretrained = None
elif isinstance(pretrained, dict):
img_pretrained = pretrained.get('img', None)
pts_pretrained = pretrained.get('pts', None)
else:
raise ValueError(
f'pretrained should be a dict, got {type(pretrained)}')
if self.with_img_backbone:
if img_pretrained is not None:
warnings.warn('DeprecationWarning: pretrained is a deprecated '
'key, please consider using init_cfg.')
self.img_backbone.init_cfg = dict(
type='Pretrained', checkpoint=img_pretrained)
if self.with_img_roi_head:
if img_pretrained is not None:
warnings.warn('DeprecationWarning: pretrained is a deprecated '
'key, please consider using init_cfg.')
self.img_roi_head.init_cfg = dict(
type='Pretrained', checkpoint=img_pretrained)
if self.with_pts_backbone:
if pts_pretrained is not None:
warnings.warn('DeprecationWarning: pretrained is a deprecated '
'key, please consider using init_cfg')
self.pts_backbone.init_cfg = dict(
type='Pretrained', checkpoint=pts_pretrained)
@property @property
def with_img_shared_head(self): def with_img_shared_head(self):
"""bool: Whether the detector has a shared head in image branch.""" """bool: Whether the detector has a shared head in image branch."""
...@@ -164,12 +167,15 @@ class MVXTwoStageDetector(Base3DDetector): ...@@ -164,12 +167,15 @@ class MVXTwoStageDetector(Base3DDetector):
return hasattr(self, return hasattr(self,
'middle_encoder') and self.middle_encoder is not None 'middle_encoder') and self.middle_encoder is not None
def extract_img_feat(self, img, img_metas): def _forward(self):
pass
def extract_img_feat(self, img: Tensor, input_metas: List[dict]) -> dict:
"""Extract features of images.""" """Extract features of images."""
if self.with_img_backbone and img is not None: if self.with_img_backbone and img is not None:
input_shape = img.shape[-2:] input_shape = img.shape[-2:]
# update real input shape of each single img # update real input shape of each single img
for img_meta in img_metas: for img_meta in input_metas:
img_meta.update(input_shape=input_shape) img_meta.update(input_shape=input_shape)
if img.dim() == 5 and img.size(0) == 1: if img.dim() == 5 and img.size(0) == 1:
...@@ -184,13 +190,30 @@ class MVXTwoStageDetector(Base3DDetector): ...@@ -184,13 +190,30 @@ class MVXTwoStageDetector(Base3DDetector):
img_feats = self.img_neck(img_feats) img_feats = self.img_neck(img_feats)
return img_feats return img_feats
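The dimension check above only shows the single-sample case; a standalone sketch of the multi-view flattening it implies (the elif branch and all shapes below are assumptions, not part of this commit):

import torch

img = torch.randn(2, 6, 3, 448, 800)   # hypothetical (B, N, C, H, W) multi-camera batch
if img.dim() == 5 and img.size(0) == 1:
    img = img.squeeze(0)                # single sample: drop the batch dimension
elif img.dim() == 5:
    B, N, C, H, W = img.size()
    img = img.view(B * N, C, H, W)      # fold camera views into the batch dimension
print(img.shape)                        # torch.Size([12, 3, 448, 800])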
def extract_pts_feat(self, pts, img_feats, img_metas): def extract_pts_feat(
"""Extract features of points.""" self,
points: List[Tensor],
img_feats: Optional[Sequence[Tensor]] = None,
batch_input_metas: Optional[List[dict]] = None
) -> Sequence[Tensor]:
"""Extract features of points.
Args:
points (List[tensor]): Point cloud of multiple inputs.
img_feats (list[Tensor], tuple[tensor], optional): Features from
image backbone.
batch_input_metas (list[dict], optional): The meta information
of multiple samples. Defaults to None.
Returns:
Sequence[Tensor]: Point features of multiple inputs
from backbone or neck.
"""
if not self.with_pts_bbox: if not self.with_pts_bbox:
return None return None
voxels, num_points, coors = self.voxelize(pts) voxels, num_points, coors = self.voxelize(points)
voxel_features = self.pts_voxel_encoder(voxels, num_points, coors, voxel_features = self.pts_voxel_encoder(voxels, num_points, coors,
img_feats, img_metas) img_feats, batch_input_metas)
batch_size = coors[-1, 0] + 1 batch_size = coors[-1, 0] + 1
x = self.pts_middle_encoder(voxel_features, coors, batch_size) x = self.pts_middle_encoder(voxel_features, coors, batch_size)
x = self.pts_backbone(x) x = self.pts_backbone(x)
...@@ -198,15 +221,32 @@ class MVXTwoStageDetector(Base3DDetector): ...@@ -198,15 +221,32 @@ class MVXTwoStageDetector(Base3DDetector):
x = self.pts_neck(x) x = self.pts_neck(x)
return x return x
def extract_feat(self, points, img, img_metas): def extract_feat(self, batch_inputs_dict: Dict[str, Optional[Tensor]],
"""Extract features from images and points.""" batch_input_metas: List[dict]) -> tuple:
img_feats = self.extract_img_feat(img, img_metas) """Extract features from images and points.
pts_feats = self.extract_pts_feat(points, img_feats, img_metas)
Args:
batch_inputs_dict (dict): Dict of batch inputs. It
contains
- points (List[tensor]): Point cloud of multiple inputs.
- imgs (tensor): Image tensor with shape (B, C, H, W).
batch_input_metas (list[dict]): Meta information of multiple inputs
in a batch.
Returns:
tuple: Two elements in a tuple, arranged as
image features and point cloud features.
"""
points = batch_inputs_dict['points']
imgs = batch_inputs_dict['imgs']
img_feats = self.extract_img_feat(imgs, batch_input_metas)
pts_feats = self.extract_pts_feat(
points, img_feats=img_feats, batch_input_metas=batch_input_metas)
return (img_feats, pts_feats) return (img_feats, pts_feats)
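A hedged usage sketch of the refactored extract_feat signature; `model`, the tensor shapes, and the meta keys are illustrative assumptions, not part of this diff:

import torch

batch_inputs_dict = dict(
    points=[torch.rand(20000, 5)],        # one (N, num_point_features) LiDAR sample
    imgs=torch.rand(1, 3, 448, 800))      # optional image tensor (B, C, H, W)
batch_input_metas = [dict(box_type_3d='LiDAR')]   # assumed minimal meta dict
img_feats, pts_feats = model.extract_feat(batch_inputs_dict, batch_input_metas)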
@torch.no_grad() @torch.no_grad()
@force_fp32() def voxelize(self, points: List[Tensor]) -> Tuple:
def voxelize(self, points):
"""Apply dynamic voxelization to points. """Apply dynamic voxelization to points.
Args: Args:
...@@ -231,95 +271,41 @@ class MVXTwoStageDetector(Base3DDetector): ...@@ -231,95 +271,41 @@ class MVXTwoStageDetector(Base3DDetector):
coors_batch = torch.cat(coors_batch, dim=0) coors_batch = torch.cat(coors_batch, dim=0)
return voxels, num_points, coors_batch return voxels, num_points, coors_batch
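The per-sample loop is elided in the hunk above; the idea is that a batch index column is prepended to each sample's voxel coordinates before concatenation, which is what lets `coors[-1, 0] + 1` recover the batch size. A standalone sketch with made-up shapes:

import torch
import torch.nn.functional as F

per_sample_coors = [torch.zeros(5, 3, dtype=torch.int32),   # sample 0: 5 voxels
                    torch.ones(3, 3, dtype=torch.int32)]    # sample 1: 3 voxels
coors_batch = torch.cat(
    [F.pad(coor, (1, 0), mode='constant', value=i)  # prepend the sample index
     for i, coor in enumerate(per_sample_coors)],
    dim=0)
print(coors_batch.shape)             # torch.Size([8, 4]); column 0 is the sample index
print(int(coors_batch[-1, 0]) + 1)   # 2, the batch size used by the middle encoder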
def forward_train(self, def loss(self, batch_inputs_dict: Dict[str, torch.Tensor],
points=None, batch_data_samples: List[Det3DDataSample],
img_metas=None, **kwargs) -> dict:
gt_bboxes_3d=None, """
gt_labels_3d=None,
gt_labels=None,
gt_bboxes=None,
img=None,
proposals=None,
gt_bboxes_ignore=None):
"""Forward training function.
Args: Args:
points (list[torch.Tensor], optional): Points of each sample. batch_inputs_dict (dict): The model input dict which includes
Defaults to None. 'points' and `imgs` keys.
img_metas (list[dict], optional): Meta information of each sample.
Defaults to None. - points (list[torch.Tensor]): Point cloud of each sample.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional): - imgs (torch.Tensor): Tensor of batch images, has shape
Ground truth 3D boxes. Defaults to None. (B, C, H, W).
gt_labels_3d (list[torch.Tensor], optional): Ground truth labels batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
of 3D boxes. Defaults to None. Samples. It usually includes information such as
gt_labels (list[torch.Tensor], optional): Ground truth labels `gt_instance_3d`.
of 2D boxes in images. Defaults to None.
gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
images. Defaults to None.
img (torch.Tensor, optional): Images of each sample with shape
(N, C, H, W). Defaults to None.
proposals ([list[torch.Tensor], optional): Predicted proposals
used for training Fast RCNN. Defaults to None.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
2D boxes in images to be ignored. Defaults to None.
Returns: Returns:
dict: Losses of different branches. dict[str, Tensor]: A dictionary of loss components.
""" """
img_feats, pts_feats = self.extract_feat(
points, img=img, img_metas=img_metas) batch_input_metas = [item.metainfo for item in batch_data_samples]
img_feats, pts_feats = self.extract_feat(batch_inputs_dict,
batch_input_metas)
losses = dict() losses = dict()
if pts_feats: if pts_feats:
losses_pts = self.forward_pts_train(pts_feats, gt_bboxes_3d, losses_pts = self.pts_bbox_head.loss(pts_feats, batch_data_samples,
gt_labels_3d, img_metas, **kwargs)
gt_bboxes_ignore)
losses.update(losses_pts) losses.update(losses_pts)
if img_feats: if img_feats:
losses_img = self.forward_img_train( losses_img = self.loss_imgs(img_feats, batch_data_samples)
img_feats,
img_metas=img_metas,
gt_bboxes=gt_bboxes,
gt_labels=gt_labels,
gt_bboxes_ignore=gt_bboxes_ignore,
proposals=proposals)
losses.update(losses_img) losses.update(losses_img)
return losses return losses
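A hedged sketch of calling the new loss entry point: ground truths travel inside Det3DDataSample objects instead of separate gt_* arguments. `model`, the box class import, and the `gt_instances_3d` attribute name are assumptions based on the docstring above:

import torch
from mmengine import InstanceData
from mmdet3d.core import Det3DDataSample
from mmdet3d.core.bbox import LiDARInstance3DBoxes   # assumed import path

gt_instances_3d = InstanceData()
gt_instances_3d.bboxes_3d = LiDARInstance3DBoxes(torch.rand(2, 7))  # two fake boxes
gt_instances_3d.labels_3d = torch.tensor([0, 2])
data_sample = Det3DDataSample()
data_sample.gt_instances_3d = gt_instances_3d

batch_inputs_dict = dict(points=[torch.rand(20000, 5)], imgs=None)
losses = model.loss(batch_inputs_dict, [data_sample])   # dict of loss tensors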
def forward_pts_train(self, def loss_imgs(self, x: List[Tensor],
pts_feats, batch_data_samples: List[Det3DDataSample], **kwargs):
gt_bboxes_3d,
gt_labels_3d,
img_metas,
gt_bboxes_ignore=None):
"""Forward function for point cloud branch.
Args:
pts_feats (list[torch.Tensor]): Features of point cloud branch
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
boxes for each sample.
gt_labels_3d (list[torch.Tensor]): Ground truth labels for
boxes of each sample.
img_metas (list[dict]): Meta information of samples.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
boxes to be ignored. Defaults to None.
Returns:
dict: Losses of each branch.
"""
outs = self.pts_bbox_head(pts_feats)
loss_inputs = outs + (gt_bboxes_3d, gt_labels_3d, img_metas)
losses = self.pts_bbox_head.loss(
*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
return losses
def forward_img_train(self,
x,
img_metas,
gt_bboxes,
gt_labels,
gt_bboxes_ignore=None,
proposals=None,
**kwargs):
"""Forward function for image branch. """Forward function for image branch.
This function works similar to the forward function of Faster R-CNN. This function works similar to the forward function of Faster R-CNN.
...@@ -327,14 +313,9 @@ class MVXTwoStageDetector(Base3DDetector): ...@@ -327,14 +313,9 @@ class MVXTwoStageDetector(Base3DDetector):
Args: Args:
x (list[torch.Tensor]): Image features of shape (B, C, H, W) x (list[torch.Tensor]): Image features of shape (B, C, H, W)
of multiple levels. of multiple levels.
img_metas (list[dict]): Meta information of images. batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
gt_bboxes (list[torch.Tensor]): Ground truth boxes of each image Samples. It usually includes information such as
sample. `gt_instance_3d`.
gt_labels (list[torch.Tensor]): Ground truth labels of boxes.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
boxes to be ignored. Defaults to None.
proposals (list[torch.Tensor], optional): Proposals of each sample.
Defaults to None.
Returns: Returns:
dict: Losses of each branch. dict: Losses of each branch.
...@@ -342,158 +323,109 @@ class MVXTwoStageDetector(Base3DDetector): ...@@ -342,158 +323,109 @@ class MVXTwoStageDetector(Base3DDetector):
losses = dict() losses = dict()
# RPN forward and loss # RPN forward and loss
if self.with_img_rpn: if self.with_img_rpn:
rpn_outs = self.img_rpn_head(x) proposal_cfg = self.test_cfg.rpn
rpn_loss_inputs = rpn_outs + (gt_bboxes, img_metas, rpn_data_samples = copy.deepcopy(batch_data_samples)
self.train_cfg.img_rpn) # set cat_id of gt_labels to 0 in RPN
rpn_losses = self.img_rpn_head.loss( for data_sample in rpn_data_samples:
*rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore) data_sample.gt_instances.labels = \
torch.zeros_like(data_sample.gt_instances.labels)
rpn_losses, rpn_results_list = self.img_rpn_head.loss_and_predict(
x, rpn_data_samples, proposal_cfg=proposal_cfg, **kwargs)
# avoid getting the same names as the roi_head losses
keys = rpn_losses.keys()
for key in list(keys):
if 'loss' in key and 'rpn' not in key:
rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key)
losses.update(rpn_losses) losses.update(rpn_losses)
proposal_cfg = self.train_cfg.get('img_rpn_proposal',
self.test_cfg.img_rpn)
proposal_inputs = rpn_outs + (img_metas, proposal_cfg)
proposal_list = self.img_rpn_head.get_bboxes(*proposal_inputs)
else: else:
proposal_list = proposals if 'proposals' in batch_data_samples[0]:
# use pre-defined proposals in InstanceData
# for the second stage
# to extract ROI features.
rpn_results_list = [
data_sample.proposals for data_sample in batch_data_samples
]
else:
rpn_results_list = None
# bbox head forward and loss # bbox head forward and loss
if self.with_img_bbox: if self.with_img_bbox:
# bbox head forward and loss roi_losses = self.img_roi_head.loss(x, rpn_results_list,
img_roi_losses = self.img_roi_head.forward_train( batch_data_samples, **kwargs)
x, img_metas, proposal_list, gt_bboxes, gt_labels, losses.update(roi_losses)
gt_bboxes_ignore, **kwargs)
losses.update(img_roi_losses)
return losses return losses
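A small self-contained illustration of the key renaming above (loss values are made up): any RPN loss whose name does not already contain 'rpn' gets prefixed so it cannot collide with the RoI head losses merged into the same dict.

rpn_losses = {'loss_cls': 0.3, 'loss_bbox': 0.1, 'loss_rpn_obj': 0.2}
for key in list(rpn_losses.keys()):
    if 'loss' in key and 'rpn' not in key:
        rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key)
print(rpn_losses)  # {'loss_rpn_obj': 0.2, 'rpn_loss_cls': 0.3, 'rpn_loss_bbox': 0.1}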
def simple_test_img(self, x, img_metas, proposals=None, rescale=False): def predict_imgs(self,
"""Test without augmentation.""" x: List[Tensor],
if proposals is None: batch_data_samples: List[Det3DDataSample],
proposal_list = self.simple_test_rpn(x, img_metas, rescale: bool = True,
self.test_cfg.img_rpn) **kwargs) -> InstanceData:
else: """Predict results from a batch of inputs and data samples with post-
proposal_list = proposals processing.
return self.img_roi_head.simple_test(
x, proposal_list, img_metas, rescale=rescale)
def simple_test_rpn(self, x, img_metas, rpn_test_cfg):
"""RPN test function."""
rpn_outs = self.img_rpn_head(x)
proposal_inputs = rpn_outs + (img_metas, rpn_test_cfg)
proposal_list = self.img_rpn_head.get_bboxes(*proposal_inputs)
return proposal_list
def simple_test_pts(self, x, img_metas, rescale=False):
"""Test function of point cloud branch."""
outs = self.pts_bbox_head(x)
bbox_list = self.pts_bbox_head.get_bboxes(
*outs, img_metas, rescale=rescale)
bbox_results = [
bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list
]
return bbox_results
def simple_test(self, points, img_metas, img=None, rescale=False):
"""Test function without augmentaiton."""
img_feats, pts_feats = self.extract_feat(
points, img=img, img_metas=img_metas)
bbox_list = [dict() for i in range(len(img_metas))]
if pts_feats and self.with_pts_bbox:
bbox_pts = self.simple_test_pts(
pts_feats, img_metas, rescale=rescale)
for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
result_dict['pts_bbox'] = pts_bbox
if img_feats and self.with_img_bbox:
bbox_img = self.simple_test_img(
img_feats, img_metas, rescale=rescale)
for result_dict, img_bbox in zip(bbox_list, bbox_img):
result_dict['img_bbox'] = img_bbox
return bbox_list
def aug_test(self, points, img_metas, imgs=None, rescale=False): Args:
"""Test function with augmentaiton.""" x (List[Tensor]): Image features from FPN.
img_feats, pts_feats = self.extract_feats(points, img_metas, imgs) batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
rescale (bool): Whether to rescale the results.
Defaults to True.
"""
bbox_list = dict() if batch_data_samples[0].get('proposals', None) is None:
if pts_feats and self.with_pts_bbox: rpn_results_list = self.img_rpn_head.predict(
bbox_pts = self.aug_test_pts(pts_feats, img_metas, rescale) x, batch_data_samples, rescale=False)
bbox_list.update(pts_bbox=bbox_pts) else:
return [bbox_list] rpn_results_list = [
data_sample.proposals for data_sample in batch_data_samples
def extract_feats(self, points, img_metas, imgs=None):
"""Extract point and image features of multiple samples."""
if imgs is None:
imgs = [None] * len(img_metas)
img_feats, pts_feats = multi_apply(self.extract_feat, points, imgs,
img_metas)
return img_feats, pts_feats
def aug_test_pts(self, feats, img_metas, rescale=False):
"""Test function of point cloud branch with augmentaiton."""
# only support aug_test for one sample
aug_bboxes = []
for x, img_meta in zip(feats, img_metas):
outs = self.pts_bbox_head(x)
bbox_list = self.pts_bbox_head.get_bboxes(
*outs, img_meta, rescale=rescale)
bbox_list = [
dict(boxes_3d=bboxes, scores_3d=scores, labels_3d=labels)
for bboxes, scores, labels in bbox_list
] ]
aug_bboxes.append(bbox_list[0]) results_list = self.img_roi_head.predict(
x, rpn_results_list, batch_data_samples, rescale=rescale, **kwargs)
return results_list
# after merging, bboxes will be rescaled to the original image size def predict(self, batch_inputs_dict: Dict[str, Optional[Tensor]],
merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas, batch_data_samples: List[Det3DDataSample],
self.pts_bbox_head.test_cfg) **kwargs) -> List[Det3DDataSample]:
return merged_bboxes """Forward of testing.
def show_results(self, data, result, out_dir):
"""Results visualization.
Args: Args:
data (dict): Input points and the information of the sample. batch_inputs_dict (dict): The model input dict which include
result (dict): Prediction results. 'points' keys.
out_dir (str): Output directory of visualization result.
- points (list[torch.Tensor]): Point cloud of each sample.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance_3d`.
Returns:
list[:obj:`Det3DDataSample`]: Detection results of the
input sample. Each Det3DDataSample usually contains
'pred_instances_3d'. And the ``pred_instances_3d`` usually
contains the following keys.
- scores_3d (Tensor): Classification scores, has a shape
(num_instances, )
- labels_3d (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes_3d (:obj:`BaseInstance3DBoxes`): Prediction of bboxes,
contains a tensor with shape (num_instances, 7).
""" """
for batch_id in range(len(result)): batch_input_metas = [item.metainfo for item in batch_data_samples]
if isinstance(data['points'][0], DC): img_feats, pts_feats = self.extract_feat(batch_inputs_dict,
points = data['points'][0]._data[0][batch_id].numpy() batch_input_metas)
elif mmcv.is_list_of(data['points'][0], torch.Tensor): if pts_feats and self.with_pts_bbox:
points = data['points'][0][batch_id] results_list_3d = self.pts_bbox_head.predict(
else: pts_feats, batch_data_samples, **kwargs)
ValueError(f"Unsupported data type {type(data['points'][0])} " else:
f'for visualization!') results_list_3d = None
if isinstance(data['img_metas'][0], DC):
pts_filename = data['img_metas'][0]._data[0][batch_id][ if img_feats and self.with_img_bbox:
'pts_filename'] # TODO check this for camera modality
box_mode_3d = data['img_metas'][0]._data[0][batch_id][ results_list_2d = self.predict_imgs(img_feats, batch_data_samples,
'box_mode_3d'] **kwargs)
elif mmcv.is_list_of(data['img_metas'][0], dict): else:
pts_filename = data['img_metas'][0][batch_id]['pts_filename'] results_list_2d = None
box_mode_3d = data['img_metas'][0][batch_id]['box_mode_3d']
else: detsamples = self.convert_to_datasample(results_list_3d,
ValueError( results_list_2d)
f"Unsupported data type {type(data['img_metas'][0])} " return detsamples
f'for visualization!')
file_name = osp.split(pts_filename)[-1].split('.')[0]
assert out_dir is not None, 'Expect out_dir, got none.'
inds = result[batch_id]['pts_bbox']['scores_3d'] > 0.1
pred_bboxes = result[batch_id]['pts_bbox']['boxes_3d'][inds]
# for now we convert points and bbox into depth mode
if (box_mode_3d == Box3DMode.CAM) or (box_mode_3d
== Box3DMode.LIDAR):
points = Coord3DMode.convert_point(points, Coord3DMode.LIDAR,
Coord3DMode.DEPTH)
pred_bboxes = Box3DMode.convert(pred_bboxes, box_mode_3d,
Box3DMode.DEPTH)
elif box_mode_3d != Box3DMode.DEPTH:
ValueError(
f'Unsupported box_mode_3d {box_mode_3d} for conversion!')
pred_bboxes = pred_bboxes.tensor.cpu().numpy()
show_result(points, None, pred_bboxes, out_dir, file_name)
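A hedged sketch of consuming the refactored predict output, reusing the `model`, `batch_inputs_dict`, and `data_sample` assumed in the sketches above; the score threshold is arbitrary:

results = model.predict(batch_inputs_dict, [data_sample])   # list[Det3DDataSample]
pred = results[0].pred_instances_3d
keep = pred.scores_3d > 0.3                  # arbitrary confidence threshold
boxes_3d = pred.bboxes_3d[keep]              # BaseInstance3DBoxes
scores_3d = pred.scores_3d[keep]             # (num_kept,) tensor
labels_3d = pred.labels_3d[keep]             # (num_kept,) tensor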
...@@ -3,7 +3,7 @@ import torch ...@@ -3,7 +3,7 @@ import torch
from mmcv.cnn import build_norm_layer from mmcv.cnn import build_norm_layer
from mmcv.ops import DynamicScatter from mmcv.ops import DynamicScatter
from mmcv.runner import force_fp32 from mmcv.runner import force_fp32
from torch import nn from torch import Tensor, nn
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
from .. import builder from .. import builder
...@@ -20,13 +20,14 @@ class HardSimpleVFE(nn.Module): ...@@ -20,13 +20,14 @@ class HardSimpleVFE(nn.Module):
num_features (int, optional): Number of features to use. Default: 4. num_features (int, optional): Number of features to use. Default: 4.
""" """
def __init__(self, num_features=4): def __init__(self, num_features: int = 4) -> None:
super(HardSimpleVFE, self).__init__() super(HardSimpleVFE, self).__init__()
self.num_features = num_features self.num_features = num_features
self.fp16_enabled = False self.fp16_enabled = False
@force_fp32(out_fp16=True) @force_fp32(out_fp16=True)
def forward(self, features, num_points, coors): def forward(self, features: Tensor, num_points: Tensor, coors: Tensor,
*args, **kwargs) -> Tensor:
"""Forward function. """Forward function.
Args: Args:
...@@ -66,7 +67,7 @@ class DynamicSimpleVFE(nn.Module): ...@@ -66,7 +67,7 @@ class DynamicSimpleVFE(nn.Module):
@torch.no_grad() @torch.no_grad()
@force_fp32(out_fp16=True) @force_fp32(out_fp16=True)
def forward(self, features, coors): def forward(self, features, coors, *args, **kwargs):
"""Forward function. """Forward function.
Args: Args:
...@@ -218,13 +219,14 @@ class DynamicVFE(nn.Module): ...@@ -218,13 +219,14 @@ class DynamicVFE(nn.Module):
center_per_point = voxel_mean[voxel_inds, ...] center_per_point = voxel_mean[voxel_inds, ...]
return center_per_point return center_per_point
@force_fp32(out_fp16=True)
def forward(self, def forward(self,
features, features,
coors, coors,
points=None, points=None,
img_feats=None, img_feats=None,
img_metas=None): img_metas=None,
*args,
**kwargs):
"""Forward functions. """Forward functions.
Args: Args:
...@@ -390,7 +392,9 @@ class HardVFE(nn.Module): ...@@ -390,7 +392,9 @@ class HardVFE(nn.Module):
num_points, num_points,
coors, coors,
img_feats=None, img_feats=None,
img_metas=None): img_metas=None,
*args,
**kwargs):
"""Forward functions. """Forward functions.
Args: Args:
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import pytest import pytest
import torch import torch
from mmengine import InstanceData
from mmdet3d.core.bbox.assigners import MaxIoUAssigner from mmdet3d.core.bbox.assigners import Max3DIoUAssigner
from mmdet3d.core.bbox.samplers import IoUNegPiecewiseSampler from mmdet3d.core.bbox.samplers import IoUNegPiecewiseSampler
def test_iou_piecewise_sampler(): def test_iou_piecewise_sampler():
if not torch.cuda.is_available(): if not torch.cuda.is_available():
pytest.skip() pytest.skip()
assigner = MaxIoUAssigner( assigner = Max3DIoUAssigner(
pos_iou_thr=0.55, pos_iou_thr=0.55,
neg_iou_thr=0.55, neg_iou_thr=0.55,
min_pos_iou=0.55, min_pos_iou=0.55,
...@@ -27,7 +28,13 @@ def test_iou_piecewise_sampler(): ...@@ -27,7 +28,13 @@ def test_iou_piecewise_sampler():
[[0, 0, 0, 10, 10, 9, 0.2], [5, 10, 10, 20, 20, 15, 0.6]], [[0, 0, 0, 10, 10, 9, 0.2], [5, 10, 10, 20, 20, 15, 0.6]],
dtype=torch.float32).cuda() dtype=torch.float32).cuda()
gt_labels = torch.tensor([1, 1], dtype=torch.int64).cuda() gt_labels = torch.tensor([1, 1], dtype=torch.int64).cuda()
assign_result = assigner.assign(bboxes, gt_bboxes, gt_labels=gt_labels) gt_instances = InstanceData()
gt_instances.bboxes_3d = gt_bboxes
gt_instances.labels_3d = gt_labels
pred_instances = InstanceData()
pred_instances.priors = bboxes
assign_result = assigner.assign(pred_instances, gt_instances)
sampler = IoUNegPiecewiseSampler( sampler = IoUNegPiecewiseSampler(
num=10, num=10,
......