Commit c9ad3605 authored by jshilong, committed by ChaimZhu

[Refactor] New version VoteNet

parent db44cc50
# dataset settings
 dataset_type = 'ScanNetDataset'
 data_root = './data/scannet/'
-class_names = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
-               'bookshelf', 'picture', 'counter', 'desk', 'curtain',
-               'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub',
-               'garbagebin')
+metainfo = dict(
+    CLASSES=('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
+             'bookshelf', 'picture', 'counter', 'desk', 'curtain',
+             'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub',
+             'garbagebin'))
 train_pipeline = [
     dict(
         type='LoadPointsFromFile',
@@ -35,9 +37,8 @@ train_pipeline = [
         rot_range=[-0.087266, 0.087266],
         scale_ratio_range=[1.0, 1.0],
         shift_height=True),
-    dict(type='DefaultFormatBundle3D', class_names=class_names),
     dict(
-        type='Collect3D',
+        type='Pack3DDetInputs',
         keys=[
             'points', 'gt_bboxes_3d', 'gt_labels_3d', 'pts_semantic_mask',
             'pts_instance_mask'
@@ -68,61 +69,51 @@ test_pipeline = [
             flip_ratio_bev_horizontal=0.5,
             flip_ratio_bev_vertical=0.5),
         dict(type='PointSample', num_points=40000),
-        dict(
-            type='DefaultFormatBundle3D',
-            class_names=class_names,
-            with_label=False),
-        dict(type='Collect3D', keys=['points'])
-    ])
+    ]),
+    dict(type='Pack3DDetInputs', keys=['points'])
 ]
-# construct a pipeline for data and gt loading in show function
-# please keep its loading function consistent with test_pipeline (e.g. client)
-eval_pipeline = [
-    dict(
-        type='LoadPointsFromFile',
-        coord_type='DEPTH',
-        shift_height=False,
-        load_dim=6,
-        use_dim=[0, 1, 2]),
-    dict(type='GlobalAlignment', rotation_axis=2),
-    dict(
-        type='DefaultFormatBundle3D',
-        class_names=class_names,
-        with_label=False),
-    dict(type='Collect3D', keys=['points'])
-]
-data = dict(
-    samples_per_gpu=8,
-    workers_per_gpu=4,
-    train=dict(
+train_dataloader = dict(
+    batch_size=8,
+    num_workers=4,
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    dataset=dict(
         type='RepeatDataset',
         times=5,
         dataset=dict(
             type=dataset_type,
             data_root=data_root,
-            ann_file=data_root + 'scannet_infos_train.pkl',
+            ann_file='scannet_infos_train.pkl',
             pipeline=train_pipeline,
             filter_empty_gt=False,
-            classes=class_names,
+            metainfo=metainfo,
             # we use box_type_3d='LiDAR' in kitti and nuscenes dataset
             # and box_type_3d='Depth' in sunrgbd and scannet dataset.
-            box_type_3d='Depth')),
-    val=dict(
+            box_type_3d='Depth')))
+val_dataloader = dict(
+    batch_size=1,
+    num_workers=1,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    dataset=dict(
         type=dataset_type,
         data_root=data_root,
-        ann_file=data_root + 'scannet_infos_val.pkl',
+        ann_file='scannet_infos_val.pkl',
         pipeline=test_pipeline,
-        classes=class_names,
+        metainfo=metainfo,
         test_mode=True,
-        box_type_3d='Depth'),
-    test=dict(
+        box_type_3d='Depth'))
+test_dataloader = dict(
+    batch_size=1,
+    num_workers=1,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    dataset=dict(
         type=dataset_type,
         data_root=data_root,
-        ann_file=data_root + 'scannet_infos_val.pkl',
+        ann_file='scannet_infos_val.pkl',
         pipeline=test_pipeline,
-        classes=class_names,
+        metainfo=metainfo,
         test_mode=True,
         box_type_3d='Depth'))
-evaluation = dict(pipeline=eval_pipeline)
+val_evaluator = dict(type='IndoorMetric')
+test_evaluator = val_evaluator
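Note for reviewers migrating their own configs: each of the new per-split `*_dataloader` dicts is consumed by mmengine instead of the old monolithic `data=dict(...)`. A minimal sketch of how such a dict becomes a regular PyTorch DataLoader, assuming mmengine's `Runner.build_dataloader` helper, prepared ScanNet data, and imported mmdet3d registries; values mirror the config above, and the pipeline is elided here.

from mmengine.runner import Runner

val_dataloader_cfg = dict(
    batch_size=1,
    num_workers=1,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type='ScanNetDataset',
        data_root='./data/scannet/',
        ann_file='scannet_infos_val.pkl',
        pipeline=[],  # use the test_pipeline from the config above
        test_mode=True,
        box_type_3d='Depth'))

# Runner.build_dataloader assembles dataset, sampler and collate_fn into a
# torch.utils.data.DataLoader.
val_loader = Runner.build_dataloader(val_dataloader_cfg)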
 default_scope = 'mmdet3d'
 default_hooks = dict(
-    optimizer=dict(type='OptimizerHook', grad_clip=None),
     timer=dict(type='IterTimerHook'),
     logger=dict(type='LoggerHook', interval=50),
     param_scheduler=dict(type='ParamSchedulerHook'),
...
 model = dict(
     type='VoteNet',
+    data_preprocessor=dict(type='Det3DDataPreprocessor'),
     backbone=dict(
         type='PointNet2SASSG',
         in_channels=4,
@@ -40,10 +41,8 @@ model = dict(
         normalize_xyz=True),
     pred_layer_cfg=dict(
         in_channels=128, shared_conv_channels=(128, 128), bias=True),
-    conv_cfg=dict(type='Conv1d'),
-    norm_cfg=dict(type='BN1d'),
     objectness_loss=dict(
-        type='CrossEntropyLoss',
+        type='mmdet.CrossEntropyLoss',
         class_weight=[0.2, 0.8],
         reduction='sum',
         loss_weight=5.0),
@@ -54,20 +53,21 @@ model = dict(
         loss_src_weight=10.0,
         loss_dst_weight=10.0),
     dir_class_loss=dict(
-        type='CrossEntropyLoss', reduction='sum', loss_weight=1.0),
+        type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=1.0),
     dir_res_loss=dict(
-        type='SmoothL1Loss', reduction='sum', loss_weight=10.0),
+        type='mmdet.SmoothL1Loss', reduction='sum', loss_weight=10.0),
     size_class_loss=dict(
-        type='CrossEntropyLoss', reduction='sum', loss_weight=1.0),
+        type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=1.0),
     size_res_loss=dict(
-        type='SmoothL1Loss', reduction='sum', loss_weight=10.0 / 3.0),
+        type='mmdet.SmoothL1Loss', reduction='sum',
+        loss_weight=10.0 / 3.0),
     semantic_loss=dict(
-        type='CrossEntropyLoss', reduction='sum', loss_weight=1.0)),
+        type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=1.0)),
     # model training and testing settings
     train_cfg=dict(
-        pos_distance_thr=0.3, neg_distance_thr=0.6, sample_mod='vote'),
+        pos_distance_thr=0.3, neg_distance_thr=0.6, sample_mode='vote'),
     test_cfg=dict(
-        sample_mod='seed',
+        sample_mode='seed',
         nms_thr=0.25,
         score_thr=0.05,
         per_class_proposal=True))
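The `mmdet.` prefix on the loss types is a registry scope: mmdet3d's MODELS registry looks the class up in mmdet's registry instead of its own, so mmdet's loss implementations are reused without re-registering them. A hedged sketch of what the head does internally with such a config (see `MODELS.build` in the VoteHead changes below); assumes mmdet and mmdet3d are both installed and their registries imported.

from mmdet3d.registry import MODELS

loss_cfg = dict(
    type='mmdet.CrossEntropyLoss',  # scope prefix resolves to mmdet's class
    class_weight=[0.2, 0.8],
    reduction='sum',
    loss_weight=5.0)
loss_objectness = MODELS.build(loss_cfg)  # an nn.Module computing the loss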
@@ -3,22 +3,47 @@
 # interval to be 20. Please change the interval accordingly if you do not
 # use a default schedule.
 # optimizer
+lr = 1e-4
+iter_num_in_epoch = 3712
 # This schedule is mainly used by models on nuScenes dataset
-optimizer = dict(type='AdamW', lr=1e-4, weight_decay=0.01)
 # max_norm=10 is better for SECOND
-optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
-lr_config = dict(
-    policy='cyclic',
-    target_ratio=(10, 1e-4),
-    cyclic_times=1,
-    step_ratio_up=0.4,
-)
-momentum_config = dict(
-    policy='cyclic',
-    target_ratio=(0.85 / 0.95, 1),
-    cyclic_times=1,
-    step_ratio_up=0.4,
-)
+optim_wrapper = dict(
+    type='OptimWrapper',
+    optimizer=dict(type='AdamW', lr=lr, weight_decay=0.01),
+    clip_grad=dict(max_norm=35, norm_type=2))
+# learning rate
+param_scheduler = [
+    dict(
+        type='CosineAnnealingLR',
+        T_max=8 * iter_num_in_epoch,
+        eta_min=lr * 10,
+        by_epoch=False,
+        begin=0,
+        end=8 * iter_num_in_epoch),
+    dict(
+        type='CosineAnnealingLR',
+        T_max=12 * iter_num_in_epoch,
+        eta_min=lr * 1e-4,
+        by_epoch=False,
+        begin=8 * iter_num_in_epoch,
+        end=20 * iter_num_in_epoch),
+    dict(
+        type='CosineAnnealingBetas',
+        T_max=8 * iter_num_in_epoch,
+        eta_min=0.85 / 0.95,
+        by_epoch=False,
+        begin=0,
+        end=8 * iter_num_in_epoch),
+    dict(
+        type='CosineAnnealingBetas',
+        T_max=12 * iter_num_in_epoch,
+        eta_min=1,
+        by_epoch=False,
+        begin=8 * iter_num_in_epoch,
+        end=20 * iter_num_in_epoch)
+]
 # runtime settings
-runner = dict(type='EpochBasedRunner', max_epochs=20)
+train_cfg = dict(by_epoch=True, max_epochs=20)
+val_cfg = dict(interval=1)
+test_cfg = dict()
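The two chained `CosineAnnealingLR` phases reproduce the old one-cycle `lr_config`: because `eta_min` of the first phase is `lr * 10`, the lr cosine-anneals upward from `lr` to `10 * lr` over the first 8 epochs' worth of iterations, then the second phase anneals it down to `lr * 1e-4` by epoch 20. A back-of-the-envelope sketch of the resulting curve; the function below is illustrative, not part of mmengine.

import math

lr = 1e-4
iters_per_epoch = 3712  # matches iter_num_in_epoch above

def cosine_lr(step):
    # Phase 1 (epochs 0-8): eta_min = 10 * lr, so the curve rises to 10 * lr.
    # Phase 2 (epochs 8-20): anneals back down to lr * 1e-4.
    up_end = 8 * iters_per_epoch
    total = 20 * iters_per_epoch
    if step < up_end:
        start, stop, t = lr, lr * 10, step / up_end
    else:
        start, stop = lr * 10, lr * 1e-4
        t = (step - up_end) / (total - up_end)
    # standard cosine interpolation from `start` (t=0) to `stop` (t=1)
    return stop + (start - stop) * (1 + math.cos(math.pi * t)) / 2

assert abs(cosine_lr(0) - lr) < 1e-12                         # starts at lr
assert abs(cosine_lr(8 * iters_per_epoch) - 10 * lr) < 1e-12  # peaks at 10*lr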
 # The schedule is usually used by models trained on KITTI dataset
 # The learning rate set in the cyclic schedule is the initial learning rate
 # rather than the max learning rate. Since the target_ratio is (10, 1e-4),
 # the learning rate will change from 0.0018 to 0.018, then go to 0.0018*1e-4
 lr = 0.0018
+iter_num_in_epoch = 3712
 # The optimizer follows the setting in SECOND.Pytorch, but here we use
 # the official AdamW optimizer implemented by PyTorch.
-optimizer = dict(type='AdamW', lr=lr, betas=(0.95, 0.99), weight_decay=0.01)
-optimizer_config = dict(grad_clip=dict(max_norm=10, norm_type=2))
-# We use cyclic learning rate and momentum schedule following SECOND.Pytorch
-# https://github.com/traveller59/second.pytorch/blob/3aba19c9688274f75ebb5e576f65cfe54773c021/torchplus/train/learning_schedules_fastai.py#L69 # noqa
-# We implement them in mmcv, for more details, please refer to
-# https://github.com/open-mmlab/mmcv/blob/f48241a65aebfe07db122e9db320c31b685dc674/mmcv/runner/hooks/lr_updater.py#L327 # noqa
-# https://github.com/open-mmlab/mmcv/blob/f48241a65aebfe07db122e9db320c31b685dc674/mmcv/runner/hooks/momentum_updater.py#L130 # noqa
-lr_config = dict(
-    policy='cyclic',
-    target_ratio=(10, 1e-4),
-    cyclic_times=1,
-    step_ratio_up=0.4,
-)
-momentum_config = dict(
-    policy='cyclic',
-    target_ratio=(0.85 / 0.95, 1),
-    cyclic_times=1,
-    step_ratio_up=0.4,
-)
+optim_wrapper = dict(
+    type='OptimWrapper',
+    optimizer=dict(type='AdamW', lr=lr, betas=(0.95, 0.99), weight_decay=0.01),
+    clip_grad=dict(max_norm=10, norm_type=2))
+# learning rate
+param_scheduler = [
+    dict(
+        type='CosineAnnealingLR',
+        T_max=16 * iter_num_in_epoch,
+        eta_min=lr * 10,
+        by_epoch=False,
+        begin=0,
+        end=16 * iter_num_in_epoch),
+    dict(
+        type='CosineAnnealingLR',
+        T_max=24 * iter_num_in_epoch,
+        eta_min=lr * 1e-4,
+        by_epoch=False,
+        begin=16 * iter_num_in_epoch,
+        end=40 * iter_num_in_epoch),
+    dict(
+        type='CosineAnnealingBetas',
+        T_max=16 * iter_num_in_epoch,
+        eta_min=0.85 / 0.95,
+        by_epoch=False,
+        begin=0,
+        end=16 * iter_num_in_epoch),
+    dict(
+        type='CosineAnnealingBetas',
+        T_max=24 * iter_num_in_epoch,
+        eta_min=1,
+        by_epoch=False,
+        begin=16 * iter_num_in_epoch,
+        end=40 * iter_num_in_epoch)
+]
+# Runtime settings, training schedule for 40e
 # Although the max_epochs is 40, this schedule is usually used with
 # RepeatDataset with repeat ratio N, thus the actual max epoch
 # number could be Nx40
-runner = dict(type='EpochBasedRunner', max_epochs=40)
+train_cfg = dict(by_epoch=True, max_epochs=40)
+val_cfg = dict(interval=1)
+test_cfg = dict()
@@ -2,8 +2,24 @@
 # This schedule is mainly used by models on indoor dataset,
 # e.g., VoteNet on SUNRGBD and ScanNet
 lr = 0.008  # max learning rate
-optimizer = dict(type='AdamW', lr=lr, weight_decay=0.01)
-optimizer_config = dict(grad_clip=dict(max_norm=10, norm_type=2))
-lr_config = dict(policy='step', warmup=None, step=[24, 32])
-# runtime settings
-runner = dict(type='EpochBasedRunner', max_epochs=36)
+optim_wrapper = dict(
+    type='OptimWrapper',
+    optimizer=dict(type='AdamW', lr=lr, weight_decay=0.01),
+    clip_grad=dict(max_norm=10, norm_type=2),
+)
+# training schedule for 3x
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=36, val_interval=1)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+# learning rate
+param_scheduler = [
+    dict(
+        type='MultiStepLR',
+        begin=0,
+        end=36,
+        by_epoch=True,
+        milestones=[24, 32],
+        gamma=0.1)
+]
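For reference, the `MultiStepLR` above keeps the lr at 0.008 until epoch 24, then multiplies by `gamma=0.1` at epochs 24 and 32. The drop pattern can be sanity-checked with a tiny illustrative helper (names are ad hoc, not mmengine API):

def step_lr(epoch, base_lr=0.008, milestones=(24, 32), gamma=0.1):
    # lr after all milestones not greater than `epoch` have fired
    return base_lr * gamma ** sum(epoch >= m for m in milestones)

assert abs(step_lr(0) - 0.008) < 1e-12
assert abs(step_lr(24) - 8e-4) < 1e-12
assert abs(step_lr(32) - 8e-5) < 1e-12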
@@ -31,6 +31,4 @@ model = dict(
             [1.1511526, 1.0546296, 0.49706793],
             [0.47535285, 0.49249494, 0.5802117]])))
-# yapf:disable
-log_config = dict(interval=30)
-# yapf:enable
+default_hooks = dict(logger=dict(type='LoggerHook', interval=30))
@@ -51,6 +51,8 @@ class Det3DDataSample(DetDataSample):
           panoptic segmentation.
         - ``pred_pts_panoptic_seg``(PixelData): Prediction of point cloud
           panoptic segmentation.
+        - ``eval_ann_info``(dict): Raw annotation, which will be passed to
+          the evaluator for online evaluation.

     Examples:
         >>> from mmengine.data import InstanceData, PixelData
...
@@ -205,7 +205,6 @@ def indoor_eval(gt_annos,
                 metric,
                 label2cat,
                 logger=None,
-                box_type_3d=None,
                 box_mode_3d=None):
     """Indoor Evaluation.
@@ -217,11 +216,11 @@ def indoor_eval(gt_annos,
             includes the following keys

             - labels_3d (torch.Tensor): Labels of boxes.
-            - boxes_3d (:obj:`BaseInstance3DBoxes`):
+            - bboxes_3d (:obj:`BaseInstance3DBoxes`):
                 3D bounding boxes in Depth coordinate.
             - scores_3d (torch.Tensor): Scores of boxes.
         metric (list[float]): IoU thresholds for computing average precisions.
-        label2cat (dict): Map from label to category.
+        label2cat (tuple): Map from label to category.
         logger (logging.Logger | str, optional): The way to print the mAP
             summary. See `mmdet.utils.print_log()` for details. Default: None.
@@ -236,7 +235,7 @@ def indoor_eval(gt_annos,
         det_anno = dt_annos[img_id]
         for i in range(len(det_anno['labels_3d'])):
             label = det_anno['labels_3d'].numpy()[i]
-            bbox = det_anno['boxes_3d'].convert_to(box_mode_3d)[i]
+            bbox = det_anno['bboxes_3d'].convert_to(box_mode_3d)[i]
             score = det_anno['scores_3d'].numpy()[i]
             if label not in pred:
                 pred[int(label)] = {}
@@ -250,15 +249,9 @@ def indoor_eval(gt_annos,
         # parse gt annotations
         gt_anno = gt_annos[img_id]
-        if gt_anno['gt_num'] != 0:
-            gt_boxes = box_type_3d(
-                gt_anno['gt_boxes_upright_depth'],
-                box_dim=gt_anno['gt_boxes_upright_depth'].shape[-1],
-                origin=(0.5, 0.5, 0.5)).convert_to(box_mode_3d)
-            labels_3d = gt_anno['class']
-        else:
-            gt_boxes = box_type_3d(np.array([], dtype=np.float32))
-            labels_3d = np.array([], dtype=np.int64)
+        gt_boxes = gt_anno['gt_bboxes_3d']
+        labels_3d = gt_anno['gt_labels_3d']
         for i in range(len(labels_3d)):
             label = labels_3d[i]
...
@@ -51,7 +51,7 @@ def merge_aug_bboxes_3d(aug_results, aug_batch_input_metas, test_cfg):
     aug_labels = torch.cat(recovered_labels, dim=0)

     # TODO: use a more elegant way to deal with nms
-    if test_cfg.use_rotate_nms:
+    if test_cfg.get('use_rotate_nms', False):
         nms_func = nms_bev
     else:
         nms_func = nms_normal_bev
@@ -83,7 +83,7 @@ def merge_aug_bboxes_3d(aug_results, aug_batch_input_metas, test_cfg):
     merged_labels = torch.cat(merged_labels, dim=0)

     _, order = merged_scores.sort(0, descending=True)
-    num = min(test_cfg.max_num, len(aug_bboxes))
+    num = min(test_cfg.get('max_num', 500), len(aug_bboxes))
     order = order[:num]
     merged_bboxes = merged_bboxes[order]
...
@@ -47,10 +47,15 @@ class Det3DDataset(BaseDataset):
             - 'Camera': Box in camera coordinates, usually
               for vision-based 3d detection.
-        filter_empty_gt (bool, optional): Whether to filter the data with
+        filter_empty_gt (bool): Whether to filter the data with
             empty GT. Defaults to True.
-        test_mode (bool, optional): Whether the dataset is in test mode.
+        test_mode (bool): Whether the dataset is in test mode.
             Defaults to False.
+        load_eval_anns (bool): Whether to load annotations in test_mode;
+            they are saved in `eval_ann_info`, which can be used by the
+            Evaluator. Defaults to True.
+        file_client_args (dict): Configuration of file client.
+            Defaults to `dict(backend='disk')`.
     """

     def __init__(self,
@@ -63,11 +68,13 @@ class Det3DDataset(BaseDataset):
                  box_type_3d: dict = 'LiDAR',
                  filter_empty_gt: bool = True,
                  test_mode: bool = False,
+                 load_eval_anns: bool = True,
                  file_client_args: dict = dict(backend='disk'),
                  **kwargs):
         # init file client
         self.file_client = mmcv.FileClient(**file_client_args)
         self.filter_empty_gt = filter_empty_gt
+        self.load_eval_anns = load_eval_anns
         _default_modality_keys = ('use_lidar', 'use_camera')
         if modality is None:
             modality = dict()
@@ -82,7 +89,6 @@ class Det3DDataset(BaseDataset):
                 f', `use_camera`) for {self.__class__.__name__}')
         self.box_type_3d, self.box_mode_3d = get_box_type(box_type_3d)
-
         if metainfo is not None and 'CLASSES' in metainfo:
             # we allow to train on subset of self.METAINFO['CLASSES']
             # map unselected labels to -1
@@ -101,6 +107,10 @@ class Det3DDataset(BaseDataset):
             }
             self.label_mapping[-1] = -1

+        # can be accessed by other components in the runner
+        metainfo['box_type_3d'] = box_type_3d
+        metainfo['label_mapping'] = self.label_mapping
+
         super().__init__(
             ann_file=ann_file,
             metainfo=metainfo,
@@ -221,7 +231,10 @@ class Det3DDataset(BaseDataset):
                     self.data_prefix.get('img', ''), img_info['img_path'])

         if not self.test_mode:
+            # used in training
             info['ann_info'] = self.parse_ann_info(info)
+        if self.test_mode and self.load_eval_anns:
+            info['eval_ann_info'] = self.parse_ann_info(info)

         return info
...
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Union
+
 import numpy as np
 from mmcv import BaseTransform
 from mmcv.transforms import to_tensor
@@ -45,14 +47,16 @@ class Pack3DDetInputs(BaseTransform):
             key = key[3:]
         return key

-    def transform(self, results: dict) -> dict:
-        """Method to pack the input data.
+    def transform(self, results: Union[dict,
+                                       List[dict]]) -> Union[dict, List[dict]]:
+        """Method to pack the input data. When ``results`` is a list of
+        dicts, it usually comes from test-time augmentation.

         Args:
-            results (dict): Result dict from the data pipeline.
+            results (dict | list[dict]): Result dict from the data pipeline.

         Returns:
-            dict:
+            dict | List[dict]:

             - 'inputs' (dict): The forward data of models. It usually contains
               following keys:
@@ -63,11 +67,40 @@ class Pack3DDetInputs(BaseTransform):
             - 'data_sample' (obj:`Det3DDataSample`): The annotation info of the
               sample.
         """
-        packed_results = dict()
+        # aug test: one result dict per augmented view
+        if isinstance(results, list):
+            pack_results = []
+            for single_result in results:
+                pack_results.append(self.pack_single_results(single_result))
+            return pack_results
+        # normal training and simple testing
+        elif isinstance(results, dict):
+            return self.pack_single_results(results)
+        else:
+            raise NotImplementedError
+
+    def pack_single_results(self, results: dict) -> dict:
+        """Method to pack a single input data dict.
+
+        Args:
+            results (dict): Result dict from the data pipeline.
+
+        Returns:
+            dict: A dict contains
+
+            - 'inputs' (dict): The forward data of models. It usually contains
+              following keys:
+
+              - points
+              - img
+
+            - 'data_sample' (obj:`Det3DDataSample`): The annotation info of the
+              sample.
+        """
         # Format 3D data
         if 'points' in results:
-            assert isinstance(results['points'], BasePoints)
-            results['points'] = results['points'].tensor
+            if isinstance(results['points'], BasePoints):
+                results['points'] = results['points'].tensor

         if 'img' in results:
@@ -134,6 +167,12 @@ class Pack3DDetInputs(BaseTransform):
         data_sample.gt_instances_3d = gt_instances_3d
         data_sample.gt_instances = gt_instances
         data_sample.seg_data = seg_data
+        if 'eval_ann_info' in results:
+            data_sample.eval_ann_info = results['eval_ann_info']
+        else:
+            data_sample.eval_ann_info = None
+
+        packed_results = dict()
         packed_results['data_sample'] = data_sample
         packed_results['inputs'] = inputs
...
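The dict-vs-list dispatch above is the core of the change. A schematic pure-Python mimic of that control flow (all names here are illustrative stand-ins, not mmdet3d API): a single dict from training or simple testing is packed directly, while a list of dicts, one per TTA view from MultiScaleFlipAug3D, is packed element-wise.

from typing import List, Union

def pack_single(results: dict) -> dict:
    # stand-in for Pack3DDetInputs.pack_single_results()
    return dict(inputs=dict(points=results['points']), data_sample=None)

def pack(results: Union[dict, List[dict]]) -> Union[dict, List[dict]]:
    if isinstance(results, list):    # aug test: one entry per TTA view
        return [pack_single(r) for r in results]
    elif isinstance(results, dict):  # normal training / simple testing
        return pack_single(results)
    raise NotImplementedError

assert isinstance(pack(dict(points=[1.0, 2.0])), dict)
assert len(pack([dict(points=[]), dict(points=[])])) == 2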
@@ -684,6 +684,9 @@ class LoadAnnotations3D(LoadAnnotations):
                 pts_instance_mask_path, dtype=np.int64)
         results['pts_instance_mask'] = pts_instance_mask
+        # 'eval_ann_info' will be passed to the evaluator
+        if 'eval_ann_info' in results:
+            results['eval_ann_info']['pts_instance_mask'] = pts_instance_mask
         return results

     def _load_semantic_seg_3d(self, results: dict) -> dict:
@@ -710,6 +713,9 @@ class LoadAnnotations3D(LoadAnnotations):
             pts_semantic_mask_path, dtype=np.int64)
         results['pts_semantic_mask'] = pts_semantic_mask
+        # 'eval_ann_info' will be passed to the evaluator
+        if 'eval_ann_info' in results:
+            results['eval_ann_info']['pts_semantic_mask'] = pts_semantic_mask
         return results

     def transform(self, results: dict) -> dict:
...
 # Copyright (c) OpenMMLab. All rights reserved.
 import warnings
 from copy import deepcopy
+from typing import Dict, List, Optional, Tuple, Union

 import mmcv
+from mmcv import BaseTransform
+from mmengine.dataset import Compose

 from mmdet3d.registry import TRANSFORMS
-from .compose import Compose


 @TRANSFORMS.register_module()
-class MultiScaleFlipAug3D(object):
+class MultiScaleFlipAug3D(BaseTransform):
     """Test-time augmentation with multiple scales and flipping.

     Args:
@@ -33,13 +35,13 @@ class MultiScaleFlipAug3D(object):
     """

     def __init__(self,
-                 transforms,
-                 img_scale,
-                 pts_scale_ratio,
-                 flip=False,
-                 flip_direction='horizontal',
-                 pcd_horizontal_flip=False,
-                 pcd_vertical_flip=False):
+                 transforms: List[dict],
+                 img_scale: Optional[Union[Tuple[int], List[Tuple[int]]]],
+                 pts_scale_ratio: Union[float, List[float]],
+                 flip: bool = False,
+                 flip_direction: str = 'horizontal',
+                 pcd_horizontal_flip: bool = False,
+                 pcd_vertical_flip: bool = False) -> None:
         self.transforms = Compose(transforms)
         self.img_scale = img_scale if isinstance(img_scale,
                                                  list) else [img_scale]
@@ -65,17 +67,17 @@ class MultiScaleFlipAug3D(object):
             warnings.warn(
                 'flip has no effect when RandomFlip is not in transforms')

-    def __call__(self, results):
+    def transform(self, results: Dict) -> List[Dict]:
         """Call function to augment common fields in results.

         Args:
             results (dict): Result dict contains the data to augment.

         Returns:
-            dict: The result dict contains the data that is augmented with
-                different scales and flips.
+            List[dict]: The list contains the data that is augmented with
+                different scales and flips.
         """
-        aug_data = []
+        aug_data_list = []
         # modified from `flip_aug = [False, True] if self.flip else [False]`
         # to reduce unnecessary scenes when using double flip augmentation
@@ -104,13 +106,9 @@ class MultiScaleFlipAug3D(object):
                         _results['pcd_vertical_flip'] = \
                             pcd_vertical_flip
                         data = self.transforms(_results)
-                        aug_data.append(data)
-        # list of dict to dict of list
-        aug_data_dict = {key: [] for key in aug_data[0]}
-        for data in aug_data:
-            for key, val in data.items():
-                aug_data_dict[key].append(val)
-        return aug_data_dict
+                        aug_data_list.append(data)
+
+        return aug_data_list

     def __repr__(self):
         """str: Return a string that describes the module."""
...
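With the wrapper now returning a plain list of result dicts, Pack3DDetInputs can pack each view independently (see the list branch above). A hedged config-level sketch of how it is typically wired into a test pipeline; the inner transforms and values below are illustrative and mirror common indoor settings rather than this exact commit:

tta_wrapper = dict(
    type='MultiScaleFlipAug3D',
    img_scale=(1333, 800),
    pts_scale_ratio=1.0,
    flip=True,
    pcd_horizontal_flip=True,
    pcd_vertical_flip=True,
    transforms=[
        dict(type='RandomFlip3D', sync_2d=False,
             flip_ratio_bev_horizontal=0.5,
             flip_ratio_bev_vertical=0.5),
        dict(type='PointSample', num_points=40000),
    ])
# Applying the wrapper to one sample yields len(img_scales) x
# len(pts_scale_ratios) x num_flip_combinations dicts, each later packed
# into its own (inputs, data_sample) pair by Pack3DDetInputs.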
 # Copyright (c) OpenMMLab. All rights reserved.
+from .indoor_metric import IndoorMetric  # noqa: F401,F403
 from .kitti_metric import KittiMetric  # noqa: F401,F403

-__all__ = ['KittiMetric']
+__all__ = ['KittiMetric', 'IndoorMetric']
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence

from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger

from mmdet3d.core import get_box_type, indoor_eval
from mmdet3d.registry import METRICS


@METRICS.register_module()
class IndoorMetric(BaseMetric):
    """Indoor scene evaluation metric.

    Args:
        iou_thr (list[float]): List of IoU thresholds used when calculating
            the metric. Defaults to [0.25, 0.5].
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    """

    def __init__(self,
                 iou_thr: List[float] = [0.25, 0.5],
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None,
                 **kwargs):
        super(IndoorMetric, self).__init__(
            prefix=prefix, collect_device=collect_device)
        self.iou_thr = iou_thr

    def process(self, data_batch: Sequence[dict],
                predictions: Sequence[dict]) -> None:
        """Process one batch of data samples and predictions.

        The processed results should be stored in ``self.results``, which
        will be used to compute the metrics when all batches have been
        processed.

        Args:
            data_batch (Sequence[dict]): A batch of data from the dataloader.
            predictions (Sequence[dict]): A batch of outputs from the model.
        """
        batch_eval_anns = [
            item['data_sample']['eval_ann_info'] for item in data_batch
        ]
        for eval_ann, pred_dict in zip(batch_eval_anns, predictions):
            pred_3d = pred_dict['pred_instances_3d']
            cpu_pred_3d = dict()
            for k, v in pred_3d.items():
                if hasattr(v, 'to'):
                    cpu_pred_3d[k] = v.to('cpu')
                else:
                    cpu_pred_3d[k] = v
            self.results.append((eval_ann, cpu_pred_3d))

    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            Dict[str, float]: The computed metrics. The keys are the names of
            the metrics, and the values are corresponding results.
        """
        logger: MMLogger = MMLogger.get_current_instance()
        ann_infos = []
        pred_results = []
        for eval_ann, single_pred_results in results:
            ann_infos.append(eval_ann)
            pred_results.append(single_pred_results)
        box_type_3d, box_mode_3d = get_box_type(
            self.dataset_meta['box_type_3d'])
        ret_dict = indoor_eval(
            ann_infos,
            pred_results,
            self.iou_thr,
            self.dataset_meta['CLASSES'],
            logger=logger,
            box_mode_3d=box_mode_3d)
        return ret_dict
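A toy driver for the ``process()`` step of this new metric, using the class as defined above (illustrative only; in real use the mmengine evaluator loop calls it, and ``compute_metrics`` additionally needs ``dataset_meta`` populated with 'box_type_3d' and 'CLASSES'):

import torch

metric = IndoorMetric(iou_thr=[0.25, 0.5])

data_batch = [dict(data_sample=dict(eval_ann_info=dict(gt_num=1)))]
predictions = [dict(pred_instances_3d=dict(
    scores_3d=torch.rand(2),
    labels_3d=torch.zeros(2, dtype=torch.long)))]

metric.process(data_batch, predictions)
assert len(metric.results) == 1              # one (ann, pred) pair stored
eval_ann, cpu_pred = metric.results[0]
assert cpu_pred['scores_3d'].device.type == 'cpu'  # tensors moved to CPU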
 # Copyright (c) OpenMMLab. All rights reserved.
 from numbers import Number
-from typing import Dict, List, Optional, Sequence, Tuple, Union
+from typing import List, Optional, Sequence, Tuple, Union

 import numpy as np
 from mmengine.data import BaseDataElement
@@ -66,19 +66,41 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
             batch_augments=batch_augments)

     def forward(self,
-                data: Sequence[dict],
-                training: bool = False) -> Tuple[Dict, Optional[list]]:
+                data: List[Union[dict, List[dict]]],
+                training: bool = False
+                ) -> Tuple[Union[dict, List[dict]], Optional[list]]:
         """Perform normalization, padding and bgr2rgb conversion based on
         ``BaseDataPreprocessor``.

         Args:
-            data (Sequence[dict]): Data sampled from the dataloader.
+            data (List[dict] | List[List[dict]]): Data from the dataloader.
+                The outer list always represents the batch dimension; when
+                it is a list[list[dict]], the inner list indicates test-time
+                augmentation.
             training (bool): Whether to enable training time augmentation.

         Returns:
-            Tuple[Dict, Optional[list]]: Data in the same format as the
-                model input.
+            Tuple[Dict, Optional[list]] |
+            Tuple[List[Dict], Optional[list[list]]]:
+                Data in the same format as the model input.
         """
+        if isinstance(data[0], list):
+            num_augs = len(data[0])
+            aug_batch_data = []
+            aug_batch_data_sample = []
+            for aug_id in range(num_augs):
+                single_aug_batch_data, \
+                    single_aug_batch_data_sample = self.simple_process(
+                        [item[aug_id] for item in data], training)
+                aug_batch_data.append(single_aug_batch_data)
+                aug_batch_data_sample.append(single_aug_batch_data_sample)
+            return aug_batch_data, aug_batch_data_sample
+        else:
+            return self.simple_process(data, training)
+
+    def simple_process(self, data: Sequence[dict], training: bool = False):
         inputs_dict, batch_data_samples = self.collate_data(data)

         if 'points' in inputs_dict[0].keys():
...
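The nesting convention is easiest to see on toy data: a plain batch is list[dict], while a TTA batch is list[list[dict]] ordered sample-major, and the branch above regroups it augmentation-major before each ``simple_process`` call. A pure-Python sketch of that regrouping, with placeholder strings standing in for real sample dicts:

batch = [['s0a0', 's0a1'], ['s1a0', 's1a1'], ['s2a0', 's2a1']]
num_augs = len(batch[0])
regrouped = [[sample[aug_id] for sample in batch]
             for aug_id in range(num_augs)]
assert regrouped == [['s0a0', 's1a0', 's2a0'], ['s0a1', 's1a1', 's2a1']]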
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Union
+
 import numpy as np
 import torch
 from mmcv.ops import furthest_point_sample
 from mmcv.runner import BaseModule, force_fp32
+from mmengine import ConfigDict, InstanceData
 from torch.nn import functional as F

 from mmdet3d.core.post_processing import aligned_3d_nms
-from mmdet3d.models.builder import build_loss
 from mmdet3d.models.losses import chamfer_distance
 from mmdet3d.models.model_utils import VoteModule
 from mmdet3d.ops import build_sa_module
-from mmdet3d.registry import MODELS
-from mmdet.core import build_bbox_coder, multi_apply
+from mmdet3d.registry import MODELS, TASK_UTILS
+from mmdet.core.utils import multi_apply
+from ...core import Det3DDataSample
 from .base_conv_bbox_head import BaseConvBboxHead
@@ -21,66 +24,76 @@ class VoteHead(BaseModule):
     Args:
         num_classes (int): The number of class.
-        bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for encoding and
-            decoding boxes.
-        train_cfg (dict): Config for training.
-        test_cfg (dict): Config for testing.
-        vote_module_cfg (dict): Config of VoteModule for point-wise votes.
-        vote_aggregation_cfg (dict): Config of vote aggregation layer.
-        pred_layer_cfg (dict): Config of classification and regression
-            prediction layers.
-        conv_cfg (dict): Config of convolution in prediction layer.
-        norm_cfg (dict): Config of BN in prediction layer.
-        objectness_loss (dict): Config of objectness loss.
-        center_loss (dict): Config of center loss.
-        dir_class_loss (dict): Config of direction classification loss.
-        dir_res_loss (dict): Config of direction residual regression loss.
-        size_class_loss (dict): Config of size classification loss.
-        size_res_loss (dict): Config of size residual regression loss.
-        semantic_loss (dict): Config of point-wise semantic segmentation loss.
+        bbox_coder (ConfigDict, dict): Bbox coder for encoding and
+            decoding boxes. Defaults to None.
+        train_cfg (dict, optional): Config for training. Defaults to None.
+        test_cfg (dict, optional): Config for testing. Defaults to None.
+        vote_module_cfg (dict, optional): Config of VoteModule for
+            point-wise votes. Defaults to None.
+        vote_aggregation_cfg (dict, optional): Config of vote
+            aggregation layer. Defaults to None.
+        pred_layer_cfg (dict, optional): Config of classification
+            and regression prediction layers. Defaults to None.
+        objectness_loss (dict, optional): Config of objectness loss.
+            Defaults to None.
+        center_loss (dict, optional): Config of center loss.
+            Defaults to None.
+        dir_class_loss (dict, optional): Config of direction
+            classification loss. Defaults to None.
+        dir_res_loss (dict, optional): Config of direction
+            residual regression loss. Defaults to None.
+        size_class_loss (dict, optional): Config of size
+            classification loss. Defaults to None.
+        size_res_loss (dict, optional): Config of size
+            residual regression loss. Defaults to None.
+        semantic_loss (dict, optional): Config of point-wise
+            semantic segmentation loss. Defaults to None.
+        iou_loss (dict, optional): Config of IoU loss for
+            regression. Defaults to None.
+        init_cfg (dict, optional): Config of model weight
+            initialization. Defaults to None.
     """
     def __init__(self,
-                 num_classes,
-                 bbox_coder,
-                 train_cfg=None,
-                 test_cfg=None,
-                 vote_module_cfg=None,
-                 vote_aggregation_cfg=None,
-                 pred_layer_cfg=None,
-                 conv_cfg=dict(type='Conv1d'),
-                 norm_cfg=dict(type='BN1d'),
-                 objectness_loss=None,
-                 center_loss=None,
-                 dir_class_loss=None,
-                 dir_res_loss=None,
-                 size_class_loss=None,
-                 size_res_loss=None,
-                 semantic_loss=None,
-                 iou_loss=None,
-                 init_cfg=None):
+                 num_classes: int,
+                 bbox_coder: Union[ConfigDict, dict],
+                 train_cfg: Optional[dict] = None,
+                 test_cfg: Optional[dict] = None,
+                 vote_module_cfg: Optional[dict] = None,
+                 vote_aggregation_cfg: Optional[dict] = None,
+                 pred_layer_cfg: Optional[dict] = None,
+                 objectness_loss: Optional[dict] = None,
+                 center_loss: Optional[dict] = None,
+                 dir_class_loss: Optional[dict] = None,
+                 dir_res_loss: Optional[dict] = None,
+                 size_class_loss: Optional[dict] = None,
+                 size_res_loss: Optional[dict] = None,
+                 semantic_loss: Optional[dict] = None,
+                 iou_loss: Optional[dict] = None,
+                 init_cfg: Optional[dict] = None):
         super(VoteHead, self).__init__(init_cfg=init_cfg)
         self.num_classes = num_classes
         self.train_cfg = train_cfg
         self.test_cfg = test_cfg
         self.gt_per_seed = vote_module_cfg['gt_per_seed']
         self.num_proposal = vote_aggregation_cfg['num_point']

-        self.objectness_loss = build_loss(objectness_loss)
-        self.center_loss = build_loss(center_loss)
-        self.dir_res_loss = build_loss(dir_res_loss)
-        self.dir_class_loss = build_loss(dir_class_loss)
-        self.size_res_loss = build_loss(size_res_loss)
+        self.loss_objectness = MODELS.build(objectness_loss)
+        self.loss_center = MODELS.build(center_loss)
+        self.loss_dir_res = MODELS.build(dir_res_loss)
+        self.loss_dir_class = MODELS.build(dir_class_loss)
+        self.loss_size_res = MODELS.build(size_res_loss)
         if size_class_loss is not None:
-            self.size_class_loss = build_loss(size_class_loss)
+            self.size_class_loss = MODELS.build(size_class_loss)
         if semantic_loss is not None:
-            self.semantic_loss = build_loss(semantic_loss)
+            self.semantic_loss = MODELS.build(semantic_loss)
         if iou_loss is not None:
-            self.iou_loss = build_loss(iou_loss)
+            self.iou_loss = MODELS.build(iou_loss)
         else:
             self.iou_loss = None

-        self.bbox_coder = build_bbox_coder(bbox_coder)
+        self.bbox_coder = TASK_UTILS.build(bbox_coder)
         self.num_sizes = self.bbox_coder.num_sizes
         self.num_dir_bins = self.bbox_coder.num_dir_bins
@@ -94,6 +107,15 @@ class VoteHead(BaseModule):
             num_cls_out_channels=self._get_cls_out_channels(),
             num_reg_out_channels=self._get_reg_out_channels())

+    @property
+    def sample_mode(self):
+        if self.training:
+            sample_mode = self.train_cfg.sample_mode
+        else:
+            sample_mode = self.test_cfg.sample_mode
+        assert sample_mode in ['vote', 'seed', 'random', 'spec']
+        return sample_mode
+
     def _get_cls_out_channels(self):
         """Return the channel number of classification outputs."""
         # Class numbers (k) + objectness (2)
@@ -106,16 +128,18 @@ class VoteHead(BaseModule):
         # size class+residual(num_sizes*4)
         return 3 + self.num_dir_bins * 2 + self.num_sizes * 4

-    def _extract_input(self, feat_dict):
+    def _extract_input(self, feat_dict: dict) -> tuple:
         """Extract inputs from features dictionary.

         Args:
             feat_dict (dict): Feature dict from backbone.

         Returns:
-            torch.Tensor: Coordinates of input points.
-            torch.Tensor: Features of input points.
-            torch.Tensor: Indices of input points.
+            tuple[Tensor]: Arranged as the following three tensors.
+
+            - Coordinates of input points.
+            - Features of input points.
+            - Indices of input points.
         """

         # for imvotenet
@@ -133,7 +157,77 @@ class VoteHead(BaseModule):
         return seed_points, seed_features, seed_indices
-    def forward(self, feat_dict, sample_mod):
+    def predict(self,
+                points: List[torch.Tensor],
+                feats_dict: Dict[str, torch.Tensor],
+                batch_data_samples: List[Det3DDataSample],
+                rescale: bool = True,
+                **kwargs) -> List[InstanceData]:
+        """
+        Args:
+            points (list[tensor]): Point clouds of multiple samples.
+            feats_dict (dict): Features from FPN or backbone.
+            batch_data_samples (List[:obj:`Det3DDataSample`]): The data
+                samples. It usually includes meta information of data.
+            rescale (bool): Whether to rescale the results to
+                the original scale.
+
+        Returns:
+            list[:obj:`InstanceData`]: List of processed predictions. Each
+            InstanceData contains 3d bounding boxes and corresponding
+            scores and labels.
+        """
+        preds_dict = self(feats_dict)
+        batch_size = len(batch_data_samples)
+        batch_input_metas = []
+        for batch_index in range(batch_size):
+            metainfo = batch_data_samples[batch_index].metainfo
+            batch_input_metas.append(metainfo)
+
+        results_list = self.predict_by_feat(
+            points, preds_dict, batch_input_metas, rescale=rescale, **kwargs)
+        return results_list
+
+    def loss(self, points: List[torch.Tensor],
+             feats_dict: Dict[str, torch.Tensor],
+             batch_data_samples: List[Det3DDataSample], **kwargs) -> dict:
+        """
+        Args:
+            points (list[tensor]): Point clouds of multiple samples.
+            feats_dict (dict): Predictions from backbone or FPN.
+            batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
+                contains the meta information of each sample and
+                corresponding annotations.
+
+        Returns:
+            dict: A dictionary of loss components.
+        """
+        preds_dict = self.forward(feats_dict)
+        batch_gt_instance_3d = []
+        batch_gt_instances_ignore = []
+        batch_input_metas = []
+        batch_pts_semantic_mask = []
+        batch_pts_instance_mask = []
+        for data_sample in batch_data_samples:
+            batch_input_metas.append(data_sample.metainfo)
+            batch_gt_instance_3d.append(data_sample.gt_instances_3d)
+            batch_gt_instances_ignore.append(
+                data_sample.get('ignored_instances', None))
+            batch_pts_semantic_mask.append(
+                data_sample.seg_data.get('pts_semantic_mask', None))
+            batch_pts_instance_mask.append(
+                data_sample.seg_data.get('pts_instance_mask', None))
+
+        loss_inputs = (points, preds_dict, batch_gt_instance_3d)
+        losses = self.loss_by_feat(
+            *loss_inputs,
+            batch_pts_semantic_mask=batch_pts_semantic_mask,
+            batch_pts_instance_mask=batch_pts_instance_mask,
+            batch_input_metas=batch_input_metas,
+            batch_gt_instances_ignore=batch_gt_instances_ignore)
+        return losses
+    def forward(self, feat_dict: dict) -> dict:
         """Forward pass.

         Note:
@@ -146,13 +240,10 @@ class VoteHead(BaseModule):
         Args:
             feat_dict (dict): Feature dict from backbone.
-            sample_mod (str): Sample mode for vote aggregation layer.
-                Valid modes are "vote", "seed", "random" and "spec".

         Returns:
             dict: Predictions of vote head.
         """
-        assert sample_mod in ['vote', 'seed', 'random', 'spec']
-
         seed_points, seed_features, seed_indices = self._extract_input(
             feat_dict)
@@ -168,11 +259,11 @@ class VoteHead(BaseModule):
             vote_offset=vote_offset)

         # 2. aggregate vote_points
-        if sample_mod == 'vote':
+        if self.sample_mode == 'vote':
             # use fps in vote_aggregation
             aggregation_inputs = dict(
                 points_xyz=vote_points, features=vote_features)
-        elif sample_mod == 'seed':
+        elif self.sample_mode == 'seed':
             # FPS on seed and choose the votes corresponding to the seeds
             sample_indices = furthest_point_sample(seed_points,
                                                    self.num_proposal)
@@ -180,7 +271,7 @@ class VoteHead(BaseModule):
                 points_xyz=vote_points,
                 features=vote_features,
                 indices=sample_indices)
-        elif sample_mod == 'random':
+        elif self.sample_mode == 'random':
             # Random sampling from the votes
             batch_size, num_seed = seed_points.shape[:2]
             sample_indices = seed_points.new_tensor(
@@ -190,7 +281,7 @@ class VoteHead(BaseModule):
                 points_xyz=vote_points,
                 features=vote_features,
                 indices=sample_indices)
-        elif sample_mod == 'spec':
+        elif self.sample_mode == 'spec':
             # Specify the new center in vote_aggregation
             aggregation_inputs = dict(
                 points_xyz=seed_points,
@@ -198,7 +289,7 @@ class VoteHead(BaseModule):
                 target_xyz=vote_points)
         else:
             raise NotImplementedError(
-                f'Sample mode {sample_mod} is not supported!')
+                f'Sample mode {self.sample_mode} is not supported!')

         vote_aggregation_ret = self.vote_aggregation(**aggregation_inputs)
         aggregated_points, features, aggregated_indices = vote_aggregation_ret
@@ -214,45 +305,42 @@ class VoteHead(BaseModule):
         decode_res = self.bbox_coder.split_pred(cls_predictions,
                                                 reg_predictions,
                                                 aggregated_points)
         results.update(decode_res)
         return results
     @force_fp32(apply_to=('bbox_preds', ))
-    def loss(self,
-             bbox_preds,
-             points,
-             gt_bboxes_3d,
-             gt_labels_3d,
-             pts_semantic_mask=None,
-             pts_instance_mask=None,
-             img_metas=None,
-             gt_bboxes_ignore=None,
-             ret_target=False):
+    def loss_by_feat(
+            self,
+            points: List[torch.Tensor],
+            bbox_preds_dict: dict,
+            batch_gt_instances_3d: List[InstanceData],
+            batch_pts_semantic_mask: Optional[List[torch.Tensor]] = None,
+            batch_pts_instance_mask: Optional[List[torch.Tensor]] = None,
+            ret_target: bool = False,
+            **kwargs) -> dict:
         """Compute loss.

         Args:
-            bbox_preds (dict): Predictions from forward of vote head.
             points (list[torch.Tensor]): Input points.
-            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
-                bboxes of each sample.
-            gt_labels_3d (list[torch.Tensor]): Labels of each sample.
-            pts_semantic_mask (list[torch.Tensor]): Point-wise
-                semantic mask.
-            pts_instance_mask (list[torch.Tensor]): Point-wise
-                instance mask.
-            img_metas (list[dict]): Contain pcd and img's meta info.
-            gt_bboxes_ignore (list[torch.Tensor]): Specify
-                which bounding.
-            ret_target (Bool): Return targets or not.
+            bbox_preds_dict (dict): Predictions from forward of vote head.
+            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
+                gt_instances. It usually includes ``bboxes`` and ``labels``
+                attributes.
+            batch_pts_semantic_mask (list[tensor]): Semantic mask
+                of point clouds. Defaults to None.
+            batch_pts_instance_mask (list[tensor]): Instance mask
+                of point clouds. Defaults to None.
+            batch_input_metas (list[dict]): Contain pcd and img's meta info.
+            ret_target (bool): Return targets or not.

         Returns:
             dict: Losses of Votenet.
         """
-        targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d,
-                                   pts_semantic_mask, pts_instance_mask,
-                                   bbox_preds)
+        targets = self.get_targets(points, bbox_preds_dict,
+                                   batch_gt_instances_3d,
+                                   batch_pts_semantic_mask,
+                                   batch_pts_instance_mask)
         (vote_targets, vote_target_masks, size_class_targets, size_res_targets,
          dir_class_targets, dir_res_targets, center_targets,
          assigned_center_targets, mask_targets, valid_gt_masks,
@@ -260,28 +348,28 @@ class VoteHead(BaseModule):
          valid_gt_weights) = targets
         # calculate vote loss
-        vote_loss = self.vote_module.get_loss(bbox_preds['seed_points'],
-                                              bbox_preds['vote_points'],
-                                              bbox_preds['seed_indices'],
-                                              vote_target_masks, vote_targets)
+        vote_loss = self.vote_module.get_loss(bbox_preds_dict['seed_points'],
+                                              bbox_preds_dict['vote_points'],
+                                              bbox_preds_dict['seed_indices'],
+                                              vote_target_masks, vote_targets)

         # calculate objectness loss
-        objectness_loss = self.objectness_loss(
-            bbox_preds['obj_scores'].transpose(2, 1),
+        objectness_loss = self.loss_objectness(
+            bbox_preds_dict['obj_scores'].transpose(2, 1),
             objectness_targets,
             weight=objectness_weights)

         # calculate center loss
-        source2target_loss, target2source_loss = self.center_loss(
-            bbox_preds['center'],
+        source2target_loss, target2source_loss = self.loss_center(
+            bbox_preds_dict['center'],
             center_targets,
             src_weight=box_loss_weights,
             dst_weight=valid_gt_weights)
         center_loss = source2target_loss + target2source_loss

         # calculate direction class loss
-        dir_class_loss = self.dir_class_loss(
-            bbox_preds['dir_class'].transpose(2, 1),
+        dir_class_loss = self.loss_dir_class(
+            bbox_preds_dict['dir_class'].transpose(2, 1),
             dir_class_targets,
             weight=box_loss_weights)
@@ -291,13 +379,13 @@ class VoteHead(BaseModule):
             (batch_size, proposal_num, self.num_dir_bins))
         heading_label_one_hot.scatter_(2, dir_class_targets.unsqueeze(-1), 1)
         dir_res_norm = torch.sum(
-            bbox_preds['dir_res_norm'] * heading_label_one_hot, -1)
-        dir_res_loss = self.dir_res_loss(
+            bbox_preds_dict['dir_res_norm'] * heading_label_one_hot, -1)
+        dir_res_loss = self.loss_dir_res(
             dir_res_norm, dir_res_targets, weight=box_loss_weights)

         # calculate size class loss
         size_class_loss = self.size_class_loss(
-            bbox_preds['size_class'].transpose(2, 1),
+            bbox_preds_dict['size_class'].transpose(2, 1),
             size_class_targets,
             weight=box_loss_weights)
@@ -308,17 +396,17 @@ class VoteHead(BaseModule):
         one_hot_size_targets_expand = one_hot_size_targets.unsqueeze(
             -1).repeat(1, 1, 1, 3).contiguous()
         size_residual_norm = torch.sum(
-            bbox_preds['size_res_norm'] * one_hot_size_targets_expand, 2)
+            bbox_preds_dict['size_res_norm'] * one_hot_size_targets_expand, 2)
         box_loss_weights_expand = box_loss_weights.unsqueeze(-1).repeat(
             1, 1, 3)
-        size_res_loss = self.size_res_loss(
+        size_res_loss = self.loss_size_res(
             size_residual_norm,
             size_res_targets,
             weight=box_loss_weights_expand)

         # calculate semantic loss
         semantic_loss = self.semantic_loss(
-            bbox_preds['sem_scores'].transpose(2, 1),
+            bbox_preds_dict['sem_scores'].transpose(2, 1),
             mask_targets,
             weight=box_loss_weights)
@@ -334,7 +422,7 @@ class VoteHead(BaseModule):
         if self.iou_loss:
             corners_pred = self.bbox_coder.decode_corners(
-                bbox_preds['center'], size_residual_norm,
+                bbox_preds_dict['center'], size_residual_norm,
                 one_hot_size_targets_expand)
             corners_target = self.bbox_coder.decode_corners(
                 assigned_center_targets, size_res_targets,
@@ -348,25 +436,26 @@ class VoteHead(BaseModule):

         return losses
def get_targets(self, def get_targets(
self,
points, points,
gt_bboxes_3d, bbox_preds: dict = None,
gt_labels_3d, batch_gt_instances_3d: List[InstanceData] = None,
pts_semantic_mask=None, batch_pts_semantic_mask: List[torch.Tensor] = None,
pts_instance_mask=None, batch_pts_instance_mask: List[torch.Tensor] = None,
bbox_preds=None): ):
"""Generate targets of vote head. """Generate targets of vote head.
Args: Args:
points (list[torch.Tensor]): Points of each batch. points (list[torch.Tensor]): Points of each batch.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
bboxes of each batch.
gt_labels_3d (list[torch.Tensor]): Labels of each batch.
pts_semantic_mask (list[torch.Tensor]): Point-wise semantic
label of each batch.
pts_instance_mask (list[torch.Tensor]): Point-wise instance
label of each batch.
bbox_preds (torch.Tensor): Bounding box predictions of vote head. bbox_preds (torch.Tensor): Bounding box predictions of vote head.
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instances. It usually includes ``bboxes`` and ``labels``
attributes.
batch_pts_semantic_mask (list[tensor]): Semantic gt mask for
multiple images.
batch_pts_instance_mask (list[tensor]): Instance gt mask for
multiple images.
Returns: Returns:
tuple[torch.Tensor]: Targets of vote head. tuple[torch.Tensor]: Targets of vote head.
@@ -374,40 +463,46 @@ class VoteHead(BaseModule):
         # find empty example
         valid_gt_masks = list()
         gt_num = list()
-        for index in range(len(gt_labels_3d)):
-            if len(gt_labels_3d[index]) == 0:
-                fake_box = gt_bboxes_3d[index].tensor.new_zeros(
-                    1, gt_bboxes_3d[index].tensor.shape[-1])
-                gt_bboxes_3d[index] = gt_bboxes_3d[index].new_box(fake_box)
-                gt_labels_3d[index] = gt_labels_3d[index].new_zeros(1)
-                valid_gt_masks.append(gt_labels_3d[index].new_zeros(1))
+        batch_gt_labels_3d = [
+            gt_instances_3d.labels_3d
+            for gt_instances_3d in batch_gt_instances_3d
+        ]
+        batch_gt_bboxes_3d = [
+            gt_instances_3d.bboxes_3d
+            for gt_instances_3d in batch_gt_instances_3d
+        ]
+        for index in range(len(batch_gt_labels_3d)):
+            if len(batch_gt_labels_3d[index]) == 0:
+                fake_box = batch_gt_bboxes_3d[index].tensor.new_zeros(
+                    1, batch_gt_bboxes_3d[index].tensor.shape[-1])
+                batch_gt_bboxes_3d[index] = batch_gt_bboxes_3d[index].new_box(
+                    fake_box)
+                batch_gt_labels_3d[index] = batch_gt_labels_3d[
+                    index].new_zeros(1)
+                valid_gt_masks.append(batch_gt_labels_3d[index].new_zeros(1))
                 gt_num.append(1)
             else:
-                valid_gt_masks.append(gt_labels_3d[index].new_ones(
-                    gt_labels_3d[index].shape))
-                gt_num.append(gt_labels_3d[index].shape[0])
+                valid_gt_masks.append(batch_gt_labels_3d[index].new_ones(
+                    batch_gt_labels_3d[index].shape))
+                gt_num.append(batch_gt_labels_3d[index].shape[0])
         max_gt_num = max(gt_num)

-        if pts_semantic_mask is None:
-            pts_semantic_mask = [None for i in range(len(gt_labels_3d))]
-            pts_instance_mask = [None for i in range(len(gt_labels_3d))]
-
         aggregated_points = [
             bbox_preds['aggregated_points'][i]
-            for i in range(len(gt_labels_3d))
+            for i in range(len(batch_gt_labels_3d))
         ]

         (vote_targets, vote_target_masks, size_class_targets, size_res_targets,
          dir_class_targets, dir_res_targets, center_targets,
-         assigned_center_targets, mask_targets, objectness_targets,
-         objectness_masks) = multi_apply(self.get_targets_single, points,
-                                         gt_bboxes_3d, gt_labels_3d,
-                                         pts_semantic_mask, pts_instance_mask,
-                                         aggregated_points)
+         assigned_center_targets, mask_targets,
+         objectness_targets, objectness_masks) = multi_apply(
+             self._get_targets_single, points, batch_gt_bboxes_3d,
+             batch_gt_labels_3d, batch_pts_semantic_mask,
+             batch_pts_instance_mask, aggregated_points)

         # pad targets as original code of votenet.
-        for index in range(len(gt_labels_3d)):
-            pad_num = max_gt_num - gt_labels_3d[index].shape[0]
+        for index in range(len(batch_gt_labels_3d)):
+            pad_num = max_gt_num - batch_gt_labels_3d[index].shape[0]
             center_targets[index] = F.pad(center_targets[index],
                                           (0, 0, 0, pad_num))
             valid_gt_masks[index] = F.pad(valid_gt_masks[index], (0, pad_num))
@@ -437,7 +532,7 @@ class VoteHead(BaseModule):
             valid_gt_masks, objectness_targets, objectness_weights,
             box_loss_weights, valid_gt_weights)
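A minimal sketch, not part of this commit, of how ground truth is packaged for the refactored `get_targets`: each sample contributes one `InstanceData` whose `bboxes_3d`/`labels_3d` attributes are unpacked into `batch_gt_bboxes_3d`/`batch_gt_labels_3d` above. A raw tensor stands in here for the `BaseInstance3DBoxes` object used in the real pipeline:

    import torch
    from mmengine import InstanceData

    gt = InstanceData()
    gt.bboxes_3d = torch.rand(2, 7)      # stand-in for BaseInstance3DBoxes
    gt.labels_3d = torch.tensor([3, 7])  # one class index per box
    batch_gt_instances_3d = [gt]         # one entry per sample in the batch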
-    def get_targets_single(self,
+    def _get_targets_single(self,
                             points,
                             gt_bboxes_3d,
                             gt_labels_3d,
@@ -501,7 +596,6 @@ class VoteHead(BaseModule):
             vote_targets = points.new_zeros([num_points, 3])
             vote_target_masks = points.new_zeros([num_points],
                                                  dtype=torch.long)
-
             for i in torch.unique(pts_instance_mask):
                 indices = torch.nonzero(
                     pts_instance_mask == i, as_tuple=False).squeeze(-1)
@@ -561,47 +655,63 @@ class VoteHead(BaseModule):
             dir_res_targets, center_targets, assigned_center_targets,
             mask_targets.long(), objectness_targets, objectness_masks)
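The loop above builds per-instance vote targets. A self-contained sketch of the idea, not part of this commit, using each instance's point centroid as a stand-in for the annotated box center that the real code uses:

    import torch

    num_points = 6
    points = torch.rand(num_points, 3)
    pts_instance_mask = torch.tensor([0, 0, 1, 1, 1, 0])  # instance id per point

    vote_targets = points.new_zeros(num_points, 3)
    vote_target_masks = points.new_zeros(num_points, dtype=torch.long)
    for i in torch.unique(pts_instance_mask):
        indices = torch.nonzero(
            pts_instance_mask == i, as_tuple=False).squeeze(-1)
        center = points[indices].mean(0)  # stand-in for the gt box center
        vote_targets[indices] = center - points[indices]  # offset each point votes
        vote_target_masks[indices] = 1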
-    def get_bboxes(self,
-                   points,
-                   bbox_preds,
-                   input_metas,
-                   rescale=False,
-                   use_nms=True):
+    def predict_by_feat(self,
+                        points: List[torch.Tensor],
+                        bbox_preds_dict: dict,
+                        batch_input_metas: List[dict],
+                        use_nms: bool = True,
+                        **kwargs) -> List[InstanceData]:
         """Generate bboxes from vote head predictions.

         Args:
-            points (torch.Tensor): Input points.
-            bbox_preds (dict): Predictions from vote head.
-            input_metas (list[dict]): Point cloud and image's meta info.
-            rescale (bool): Whether to rescale bboxes.
+            points (list[torch.Tensor]): Input points of multiple samples.
+            bbox_preds_dict (dict): Predictions from vote head.
+            batch_input_metas (list[dict]): Each item contains the meta
+                information of the corresponding sample.
             use_nms (bool): Whether to apply NMS, skip nms postprocessing
                 while using vote head in rpn stage.

         Returns:
-            list[tuple[torch.Tensor]]: Bounding boxes, scores and labels.
+            list[:obj:`InstanceData`]: List of processed predictions. Each
+                InstanceData contains 3d bounding boxes and the corresponding
+                scores and labels.
         """
         # decode boxes
-        obj_scores = F.softmax(bbox_preds['obj_scores'], dim=-1)[..., -1]
-        sem_scores = F.softmax(bbox_preds['sem_scores'], dim=-1)
-        bbox3d = self.bbox_coder.decode(bbox_preds)
+        stack_points = torch.stack(points)
+        obj_scores = F.softmax(bbox_preds_dict['obj_scores'], dim=-1)[..., -1]
+        sem_scores = F.softmax(bbox_preds_dict['sem_scores'], dim=-1)
+        bbox3d = self.bbox_coder.decode(bbox_preds_dict)

-        if use_nms:
-            batch_size = bbox3d.shape[0]
-            results = list()
-            for b in range(batch_size):
+        batch_size = bbox3d.shape[0]
+        results_list = list()
+        for b in range(batch_size):
+            temp_results = InstanceData()
+            if use_nms:
                 bbox_selected, score_selected, labels = \
-                    self.multiclass_nms_single(obj_scores[b], sem_scores[b],
-                                               bbox3d[b], points[b, ..., :3],
-                                               input_metas[b])
-                bbox = input_metas[b]['box_type_3d'](
+                    self.multiclass_nms_single(obj_scores[b],
+                                               sem_scores[b],
+                                               bbox3d[b],
+                                               stack_points[b, ..., :3],
+                                               batch_input_metas[b])
+                bbox = batch_input_metas[b]['box_type_3d'](
                     bbox_selected,
                     box_dim=bbox_selected.shape[-1],
                     with_yaw=self.bbox_coder.with_rot)
-                results.append((bbox, score_selected, labels))
-
-            return results
-        else:
-            return bbox3d
+                temp_results.bboxes_3d = bbox
+                temp_results.scores_3d = score_selected
+                temp_results.labels_3d = labels
+                results_list.append(temp_results)
+            else:
+                # keep the raw decoded boxes when NMS is skipped (rpn stage)
+                bbox = batch_input_metas[b]['box_type_3d'](
+                    bbox3d[b],
+                    box_dim=bbox3d[b].shape[-1],
+                    with_yaw=self.bbox_coder.with_rot)
+                temp_results.bboxes_3d = bbox
+                temp_results.obj_scores_3d = obj_scores[b]
+                temp_results.sem_scores_3d = sem_scores[b]
+                results_list.append(temp_results)

+        return results_list
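A small sketch, not part of this commit, of the score handling above in isolation: `obj_scores` keeps only the foreground channel of the objectness logits, while `sem_scores` stays a per-class distribution that `multiclass_nms_single` consumes. The shapes (2 samples, 256 proposals, 18 classes) are illustrative:

    import torch
    import torch.nn.functional as F

    batch, num_proposal, num_classes = 2, 256, 18
    raw_obj = torch.randn(batch, num_proposal, 2)            # bg/fg logits
    raw_sem = torch.randn(batch, num_proposal, num_classes)  # class logits

    obj_scores = F.softmax(raw_obj, dim=-1)[..., -1]  # (2, 256) foreground prob
    sem_scores = F.softmax(raw_sem, dim=-1)           # (2, 256, 18)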
    def multiclass_nms_single(self, obj_scores, sem_scores, bbox, points,
                              input_meta):
...
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Union
+
 from mmdet3d.core import Det3DDataSample
 from mmdet3d.core.utils import (ForwardResults, InstanceList, OptConfigType,
                                 OptMultiConfig, OptSampleList, SampleList)
@@ -24,8 +26,8 @@ class Base3DDetector(BaseDetector):
         super().__init__(data_preprocessor=data_processor, init_cfg=init_cfg)
     def forward(self,
-                batch_inputs_dict: dict,
-                batch_data_samples: OptSampleList = None,
+                inputs: Union[dict, List[dict]],
+                data_samples: OptSampleList = None,
                 mode: str = 'tensor',
                 **kwargs) -> ForwardResults:
         """The unified entry for a forward process in both training and test.
@@ -43,10 +45,19 @@ class Base3DDetector(BaseDetector):
         optimizer updating, which are done in the :meth:`train_step`.

         Args:
-            batch_inputs (torch.Tensor): The input tensor with shape
-                (N, C, ...) in general.
-            batch_data_samples (list[:obj:`DetDataSample`], optional): The
-                annotation data of every samples. Defaults to None.
+            inputs (dict | list[dict]): When it is a list[dict], the outer
+                list indicates test-time augmentation. Each dict contains
+                batch inputs, which include 'points' and 'imgs' keys.
+
+                - points (list[torch.Tensor]): Point cloud of each sample.
+                - imgs (torch.Tensor): Image tensor with shape (B, C, H, W).
+            data_samples (list[:obj:`DetDataSample`],
+                list[list[:obj:`DetDataSample`]], optional): The annotation
+                data of each sample. When it is a list[list], the outer list
+                indicates test-time augmentation and the inner list indicates
+                the batch. Otherwise, the list simply indicates the batch.
+                Defaults to None.
             mode (str): Return what kind of value. Defaults to 'tensor'.

         Returns:
@@ -57,13 +68,20 @@ class Base3DDetector(BaseDetector):
             - If ``mode="loss"``, return a dict of tensor.
         """
         if mode == 'loss':
-            return self.loss(batch_inputs_dict, batch_data_samples, **kwargs)
+            return self.loss(inputs, data_samples, **kwargs)
         elif mode == 'predict':
-            return self.predict(batch_inputs_dict, batch_data_samples,
-                                **kwargs)
+            if isinstance(data_samples[0], list):
+                # aug test
+                assert len(data_samples[0]) == 1, \
+                    'Only support batch_size of 1 in mmdet3d when doing ' \
+                    'test-time augmentation.'
+                return self.aug_test(inputs, data_samples, **kwargs)
+            else:
+                return self.predict(inputs, data_samples, **kwargs)
         elif mode == 'tensor':
-            return self._forward(batch_inputs_dict, batch_data_samples,
-                                 **kwargs)
+            return self._forward(inputs, data_samples, **kwargs)
         else:
             raise RuntimeError(f'Invalid mode "{mode}". '
                                'Only supports loss, predict and tensor mode')
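A hedged sketch, not part of this commit, of how the three modes of the unified `forward` are typically driven; `model`, `batch_inputs` and `batch_data_samples` are assumed to come from `MODELS.build()` and a dataloader:

    def run_all_modes(model, batch_inputs, batch_data_samples):
        # 'loss' -> dict of loss tensors consumed by the optimizer step
        losses = model(batch_inputs, batch_data_samples, mode='loss')
        # 'predict' -> post-processed data samples; nesting the samples as
        # list[list[...]] routes to aug_test instead (batch size must be 1)
        preds = model(batch_inputs, batch_data_samples, mode='predict')
        # 'tensor' -> raw network outputs, useful for export or debugging
        feats = model(batch_inputs, batch_data_samples, mode='tensor')
        return losses, preds, feats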
...
 # Copyright (c) OpenMMLab. All rights reserved.
-import torch
+from typing import Dict, List, Optional, Union

-from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d
+from mmengine import InstanceData
+from torch import Tensor
+
+from mmdet3d.core import Det3DDataSample, merge_aug_bboxes_3d
 from mmdet3d.registry import MODELS
 from .single_stage import SingleStage3DDetector
 @MODELS.register_module()
 class VoteNet(SingleStage3DDetector):
-    r"""`VoteNet <https://arxiv.org/pdf/1904.09664.pdf>`_ for 3D detection."""
+    r"""`VoteNet <https://arxiv.org/pdf/1904.09664.pdf>`_ for 3D detection.
+
+    Args:
+        backbone (dict): Config dict of detector's backbone.
+        bbox_head (dict, optional): Config dict of box head. Defaults to None.
+        train_cfg (dict, optional): Config dict of training hyper-parameters.
+            Defaults to None.
+        test_cfg (dict, optional): Config dict of test hyper-parameters.
+            Defaults to None.
+        init_cfg (dict, optional): The config to control the initialization.
+            Defaults to None.
+        data_preprocessor (dict or ConfigDict, optional): The pre-process
+            config of :class:`BaseDataPreprocessor`. It usually includes
+            ``pad_size_divisor``, ``pad_value``, ``mean`` and ``std``.
+    """
     def __init__(self,
-                 backbone,
-                 bbox_head=None,
-                 train_cfg=None,
-                 test_cfg=None,
-                 init_cfg=None,
-                 pretrained=None):
+                 backbone: dict,
+                 bbox_head: Optional[dict] = None,
+                 train_cfg: Optional[dict] = None,
+                 test_cfg: Optional[dict] = None,
+                 init_cfg: Optional[dict] = None,
+                 data_preprocessor: Optional[dict] = None,
+                 **kwargs):
         super(VoteNet, self).__init__(
             backbone=backbone,
             bbox_head=bbox_head,
             train_cfg=train_cfg,
             test_cfg=test_cfg,
-            init_cfg=None,
-            pretrained=pretrained)
+            init_cfg=init_cfg,
+            data_preprocessor=data_preprocessor,
+            **kwargs)
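A config-style sketch, not part of this commit, of building the refactored detector through the registry; the `backbone`/`bbox_head` dicts are abbreviated placeholders, and a real config also needs the head's classes, bbox coder and losses:

    from mmdet3d.registry import MODELS

    votenet_cfg = dict(
        type='VoteNet',
        backbone=dict(type='PointNet2SASSG', in_channels=4),  # abbreviated
        bbox_head=dict(type='VoteHead'),                      # abbreviated
        train_cfg=dict(sample_mod='vote'),
        test_cfg=dict(sample_mod='seed'),
        data_preprocessor=dict(type='Det3DDataPreprocessor'))
    # model = MODELS.build(votenet_cfg)  # with the full per-dataset config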
-    def forward_train(self,
-                      points,
-                      img_metas,
-                      gt_bboxes_3d,
-                      gt_labels_3d,
-                      pts_semantic_mask=None,
-                      pts_instance_mask=None,
-                      gt_bboxes_ignore=None):
-        """Forward of training.
-
+    def loss(self, batch_inputs_dict: Dict[str, Union[List, Tensor]],
+             batch_data_samples: List[Det3DDataSample],
+             **kwargs) -> dict:
+        """
         Args:
-            points (list[torch.Tensor]): Points of each batch.
-            img_metas (list): Image metas.
-            gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): gt bboxes of each batch.
-            gt_labels_3d (list[torch.Tensor]): gt class labels of each batch.
-            pts_semantic_mask (list[torch.Tensor]): point-wise semantic
-                label of each batch.
-            pts_instance_mask (list[torch.Tensor]): point-wise instance
-                label of each batch.
-            gt_bboxes_ignore (list[torch.Tensor]): Specify which bounding
-                boxes to ignore.
+            batch_inputs_dict (dict): The model input dict which includes the
+                'points' key.
+
+                - points (list[torch.Tensor]): Point cloud of each sample.
+            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
+                Samples. It usually includes information such as
+                `gt_instances_3d`.

         Returns:
-            dict: Losses.
+            dict[str, Tensor]: A dictionary of loss components.
         """
-        points_cat = torch.stack(points)
-
-        x = self.extract_feat(points_cat)
-        bbox_preds = self.bbox_head(x, self.train_cfg.sample_mod)
-        loss_inputs = (points, gt_bboxes_3d, gt_labels_3d, pts_semantic_mask,
-                       pts_instance_mask, img_metas)
-        losses = self.bbox_head.loss(
-            bbox_preds, *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
+        feat_dict = self.extract_feat(batch_inputs_dict)
+        points = batch_inputs_dict['points']
+        losses = self.bbox_head.loss(points, feat_dict, batch_data_samples,
+                                     **kwargs)
         return losses
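A minimal sketch, not part of this commit, of a raw training step built on the new `loss` interface; in practice MMEngine drives this through `train_step` and its own loss parsing, so the reduction here is a simplification:

    def training_step(model, batch_inputs_dict, batch_data_samples, optimizer):
        losses = model.loss(batch_inputs_dict, batch_data_samples)
        total = sum(loss.sum() for loss in losses.values())  # naive reduce
        optimizer.zero_grad()
        total.backward()
        optimizer.step()
        return total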
-    def simple_test(self, points, img_metas, imgs=None, rescale=False):
+    def predict(self, batch_inputs_dict: Dict[str, Optional[Tensor]],
+                batch_data_samples: List[Det3DDataSample],
+                **kwargs) -> List[Det3DDataSample]:
         """Forward of testing.

         Args:
-            points (list[torch.Tensor]): Points of each sample.
-            img_metas (list): Image metas.
-            rescale (bool): Whether to rescale results.
+            batch_inputs_dict (dict): The model input dict which includes the
+                'points' key.
+
+                - points (list[torch.Tensor]): Point cloud of each sample.
+            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
+                Samples. It usually includes information such as
+                `gt_instances_3d`.

         Returns:
-            list: Predicted 3d boxes.
+            list[:obj:`Det3DDataSample`]: Detection results of the input
+            samples. Each Det3DDataSample usually contains
+            'pred_instances_3d', and the ``pred_instances_3d`` usually
+            contains the following keys.
+
+            - scores_3d (Tensor): Classification scores, has a shape
+                (num_instances, )
+            - labels_3d (Tensor): Labels of bboxes, has a shape
+                (num_instances, ).
+            - bboxes_3d (:obj:`BaseInstance3DBoxes`): Prediction of bboxes,
+                contains a tensor with shape (num_instances, 7).
         """
-        points_cat = torch.stack(points)
-
-        x = self.extract_feat(points_cat)
-        bbox_preds = self.bbox_head(x, self.test_cfg.sample_mod)
-        bbox_list = self.bbox_head.get_bboxes(
-            points_cat, bbox_preds, img_metas, rescale=rescale)
-        bbox_results = [
-            bbox3d2result(bboxes, scores, labels)
-            for bboxes, scores, labels in bbox_list
-        ]
-        return bbox_results
+        feats_dict = self.extract_feat(batch_inputs_dict)
+        points = batch_inputs_dict['points']
+        results_list = self.bbox_head.predict(points, feats_dict,
+                                              batch_data_samples, **kwargs)
+        data_3d_samples = self.convert_to_datasample(results_list)
+        return data_3d_samples

-    def aug_test(self, points, img_metas, imgs=None, rescale=False):
-        """Test with augmentation."""
-        points_cat = [torch.stack(pts) for pts in points]
-        feats = self.extract_feats(points_cat, img_metas)
-
-        # only support aug_test for one sample
-        aug_bboxes = []
-        for x, pts_cat, img_meta in zip(feats, points_cat, img_metas):
-            bbox_preds = self.bbox_head(x, self.test_cfg.sample_mod)
-            bbox_list = self.bbox_head.get_bboxes(
-                pts_cat, bbox_preds, img_meta, rescale=rescale)
-            bbox_list = [
-                dict(boxes_3d=bboxes, scores_3d=scores, labels_3d=labels)
-                for bboxes, scores, labels in bbox_list
-            ]
-            aug_bboxes.append(bbox_list[0])
+    def aug_test(self, aug_inputs_list: List[dict],
+                 aug_data_samples: List[List[dict]], **kwargs):
+        """Test with augmentation.
+
+        Batch size is always 1 when doing the aug test.
+
+        Args:
+            aug_inputs_list (List[dict]): The list indicates the same data
+                under different augmentations.
+            aug_data_samples (List[List[dict]]): The outer list indicates
+                different augmentations, and the inner list indicates the
+                batch size.
+        """
+        num_augs = len(aug_inputs_list)
+        if num_augs == 1:
+            return self.predict(aug_inputs_list[0], aug_data_samples[0])
+
+        batch_size = len(aug_data_samples[0])
+        assert batch_size == 1
+        multi_aug_results = []
+        for aug_id in range(num_augs):
+            batch_inputs_dict = aug_inputs_list[aug_id]
+            batch_data_samples = aug_data_samples[aug_id]
+            feats_dict = self.extract_feat(batch_inputs_dict)
+            points = batch_inputs_dict['points']
+            results_list = self.bbox_head.predict(points, feats_dict,
+                                                  batch_data_samples,
+                                                  **kwargs)
+            multi_aug_results.append(results_list[0])
+
+        aug_input_metas_list = []
+        for aug_index in range(num_augs):
+            metainfo = aug_data_samples[aug_index][0].metainfo
+            aug_input_metas_list.append(metainfo)
+
+        aug_results_list = [item.to_dict() for item in multi_aug_results]
         # after merging, bboxes will be rescaled to the original image size
-        merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
-                                            self.bbox_head.test_cfg)
-
-        return [merged_bboxes]
+        merged_results_dict = merge_aug_bboxes_3d(aug_results_list,
+                                                  aug_input_metas_list,
+                                                  self.bbox_head.test_cfg)
+        merged_results = InstanceData(**merged_results_dict)
+        data_3d_samples = self.convert_to_datasample([merged_results])
+        return data_3d_samples
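The calling contract the refactored `aug_test` expects, as a short sketch that is not part of this commit: one inputs dict per augmentation and a one-element list of data samples per augmentation (batch size 1):

    def run_tta(model, aug_inputs_list, aug_data_samples):
        # aug_inputs_list: list[dict], one dict per augmentation
        # aug_data_samples: list[list[Det3DDataSample]], inner lists of length 1
        assert all(len(samples) == 1 for samples in aug_data_samples)
        return model.aug_test(aug_inputs_list, aug_data_samples)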