Commit 3c57cc41 authored by jshilong, committed by ChaimZhu

[Refactor] Refactor the interface of 3DSSD

parent bd73d3b9
@@ -53,8 +53,9 @@ train_pipeline = [
     # 3DSSD can get a higher performance without this transform
     # dict(type='BackgroundPointsFilter', bbox_enlarge_range=(0.5, 2.0, 0.5)),
     dict(type='PointSample', num_points=16384),
-    dict(type='DefaultFormatBundle3D', class_names=class_names),
-    dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
+    dict(
+        type='Pack3DDetInputs',
+        keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
 ]
 test_pipeline = [
@@ -79,22 +80,14 @@ test_pipeline = [
             dict(
                 type='PointsRangeFilter', point_cloud_range=point_cloud_range),
             dict(type='PointSample', num_points=16384),
-            dict(
-                type='DefaultFormatBundle3D',
-                class_names=class_names,
-                with_label=False),
-            dict(type='Collect3D', keys=['points'])
-        ])
+        ]),
+    dict(type='Pack3DDetInputs', keys=['points'])
 ]
-data = dict(
-    samples_per_gpu=4,
-    workers_per_gpu=4,
-    train=dict(dataset=dict(pipeline=train_pipeline)),
-    val=dict(pipeline=test_pipeline),
-    test=dict(pipeline=test_pipeline))
-evaluation = dict(interval=2)
+train_dataloader = dict(
+    batch_size=4, dataset=dict(dataset=dict(pipeline=train_pipeline)))
+test_dataloader = dict(dataset=dict(pipeline=test_pipeline))
+val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
 # model settings
 model = dict(
@@ -105,17 +98,24 @@ model = dict(
 # optimizer
 lr = 0.002  # max learning rate
-optimizer = dict(type='AdamW', lr=lr, weight_decay=0)
-optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
-lr_config = dict(policy='step', warmup=None, step=[45, 60])
-# runtime settings
-runner = dict(type='EpochBasedRunner', max_epochs=80)
-# yapf:disable
-log_config = dict(
-    interval=30,
-    hooks=[
-        dict(type='TextLoggerHook'),
-        dict(type='TensorboardLoggerHook')
-    ])
-# yapf:enable
+optim_wrapper = dict(
+    type='OptimWrapper',
+    optimizer=dict(type='AdamW', lr=lr, weight_decay=0.),
+    clip_grad=dict(max_norm=35, norm_type=2),
+)
+
+# training schedule for 1x
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=80, val_interval=2)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+# learning rate
+param_scheduler = [
+    dict(
+        type='MultiStepLR',
+        begin=0,
+        end=80,
+        by_epoch=True,
+        milestones=[45, 60],
+        gamma=0.1)
+]
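Note: with this layout, the dataloaders, optimizer wrapper, and schedule are consumed by mmengine's Runner instead of the old mmcv runner hooks. A minimal launch sketch is shown below, assuming the standard mmengine API; the config path and work_dir are illustrative and not part of this commit.

# Hedged sketch: launching training with the refactored config via mmengine.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('configs/3dssd/3dssd_4x4_kitti-3d-car.py')  # illustrative path
cfg.work_dir = './work_dirs/3dssd_kitti_car'  # example output directory

# Runner builds the model, dataloaders, optim_wrapper and param_scheduler
# declared above, then runs the EpochBasedTrainLoop for 80 epochs.
runner = Runner.from_cfg(cfg)
runner.train()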
@@ -69,9 +69,9 @@ test_pipeline = [
                 translation_std=[0, 0, 0]),
             dict(type='RandomFlip3D'),
             dict(
-                type='PointsRangeFilter', point_cloud_range=point_cloud_range),
-            dict(type='Pack3DDetInputs', keys=['points']),
-        ])
+                type='PointsRangeFilter', point_cloud_range=point_cloud_range)
+        ]),
+    dict(type='Pack3DDetInputs', keys=['points'])
 ]
 # construct a pipeline for data and gt loading in show function
 # please keep its loading function consistent with test_pipeline (e.g. client)
@@ -82,7 +82,7 @@ eval_pipeline = [
         load_dim=4,
         use_dim=4,
         file_client_args=file_client_args),
-    dict(type='Pack3DDetInputs', keys=['points']),
+    dict(type='Pack3DDetInputs', keys=['points'])
 ]
 train_dataloader = dict(
     batch_size=6,
......
 model = dict(
     type='SSD3DNet',
+    data_preprocessor=dict(type='Det3DDataPreprocessor'),
     backbone=dict(
         type='PointNet2SAMSG',
         in_channels=4,
@@ -20,7 +21,6 @@ model = dict(
             normalize_xyz=False)),
     bbox_head=dict(
         type='SSD3DHead',
-        in_channels=256,
         vote_module_cfg=dict(
             in_channels=256,
             num_points=256,
@@ -48,30 +48,29 @@ model = dict(
             conv_cfg=dict(type='Conv1d'),
             norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.1),
             bias=True),
-        conv_cfg=dict(type='Conv1d'),
-        norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.1),
         objectness_loss=dict(
-            type='CrossEntropyLoss',
+            type='mmdet.CrossEntropyLoss',
             use_sigmoid=True,
             reduction='sum',
             loss_weight=1.0),
         center_loss=dict(
-            type='SmoothL1Loss', reduction='sum', loss_weight=1.0),
+            type='mmdet.SmoothL1Loss', reduction='sum', loss_weight=1.0),
         dir_class_loss=dict(
-            type='CrossEntropyLoss', reduction='sum', loss_weight=1.0),
+            type='mmdet.CrossEntropyLoss', reduction='sum', loss_weight=1.0),
         dir_res_loss=dict(
-            type='SmoothL1Loss', reduction='sum', loss_weight=1.0),
+            type='mmdet.SmoothL1Loss', reduction='sum', loss_weight=1.0),
         size_res_loss=dict(
-            type='SmoothL1Loss', reduction='sum', loss_weight=1.0),
+            type='mmdet.SmoothL1Loss', reduction='sum', loss_weight=1.0),
         corner_loss=dict(
-            type='SmoothL1Loss', reduction='sum', loss_weight=1.0),
-        vote_loss=dict(type='SmoothL1Loss', reduction='sum', loss_weight=1.0)),
+            type='mmdet.SmoothL1Loss', reduction='sum', loss_weight=1.0),
+        vote_loss=dict(
+            type='mmdet.SmoothL1Loss', reduction='sum', loss_weight=1.0)),
     # model training and testing settings
     train_cfg=dict(
-        sample_mod='spec', pos_distance_thr=10.0, expand_dims_length=0.05),
+        sample_mode='spec', pos_distance_thr=10.0, expand_dims_length=0.05),
     test_cfg=dict(
         nms_cfg=dict(type='nms', iou_thr=0.1),
-        sample_mod='spec',
+        sample_mode='spec',
         score_thr=0.0,
         per_class_proposal=True,
         max_output_num=100))
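As a quick sanity check, the refactored base model above can be built through the mmdet3d registry, mirroring the pattern used by the unit test added later in this commit. A minimal sketch, assuming the usual mmengine Config/DefaultScope API; the scope name and config path are illustrative.

# Hedged sketch: instantiate the refactored SSD3DNet from a config.
from mmengine import DefaultScope
from mmengine.config import Config

from mmdet3d.registry import MODELS

# The default scope lets 'mmdet.'-prefixed losses resolve across registries.
DefaultScope.get_instance('ssd3d_demo', scope_name='mmdet3d')

cfg = Config.fromfile('configs/3dssd/3dssd_4x4_kitti-3d-car.py')  # illustrative path
model = MODELS.build(cfg.model)
print(type(model).__name__)  # expected: SSD3DNet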
@@ -57,9 +57,9 @@ test_pipeline = [
                 translation_std=[0, 0, 0]),
             dict(type='RandomFlip3D'),
             dict(
-                type='PointsRangeFilter', point_cloud_range=point_cloud_range),
-            dict(type='Pack3DDetInputs', keys=['points'])
-        ])
+                type='PointsRangeFilter', point_cloud_range=point_cloud_range)
+        ]),
+    dict(type='Pack3DDetInputs', keys=['points'])
 ]
 # construct a pipeline for data and gt loading in show function
 # please keep its loading function consistent with test_pipeline (e.g. client)
......
+# TODO refactor the config of sunrgbd
 _base_ = [
     '../_base_/datasets/sunrgbd-3d-10class.py', '../_base_/models/votenet.py',
     '../_base_/schedules/schedule_3x.py', '../_base_/default_runtime.py'
......
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Tuple, Union
 import torch
+from mmcv import ConfigDict
 from mmcv.ops.nms import batched_nms
-from mmcv.runner import force_fp32
+from mmengine import InstanceData
+from torch import Tensor
 from torch.nn import functional as F
 from mmdet3d.core.bbox.structures import (DepthInstance3DBoxes,
@@ -9,6 +13,7 @@ from mmdet3d.core.bbox.structures import (DepthInstance3DBoxes,
                                           rotation_3d_in_axis)
 from mmdet3d.registry import MODELS
 from mmdet.core import multi_apply
+from ...core import BaseInstance3DBoxes
 from ..builder import build_loss
 from .vote_head import VoteHead
@@ -21,7 +26,6 @@ class SSD3DHead(VoteHead):
         num_classes (int): The number of class.
         bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for encoding and
             decoding boxes.
-        in_channels (int): The number of input feature channel.
         train_cfg (dict): Config for training.
         test_cfg (dict): Config for testing.
         vote_module_cfg (dict): Config of VoteModule for point-wise votes.
@@ -41,25 +45,21 @@ class SSD3DHead(VoteHead):
     """

     def __init__(self,
-                 num_classes,
-                 bbox_coder,
-                 in_channels=256,
-                 train_cfg=None,
-                 test_cfg=None,
-                 vote_module_cfg=None,
-                 vote_aggregation_cfg=None,
-                 pred_layer_cfg=None,
-                 conv_cfg=dict(type='Conv1d'),
-                 norm_cfg=dict(type='BN1d'),
-                 act_cfg=dict(type='ReLU'),
-                 objectness_loss=None,
-                 center_loss=None,
-                 dir_class_loss=None,
-                 dir_res_loss=None,
-                 size_res_loss=None,
-                 corner_loss=None,
-                 vote_loss=None,
-                 init_cfg=None):
+                 num_classes: int,
+                 bbox_coder: Union[ConfigDict, dict],
+                 train_cfg: Optional[dict] = None,
+                 test_cfg: Optional[dict] = None,
+                 vote_module_cfg: Optional[dict] = None,
+                 vote_aggregation_cfg: Optional[dict] = None,
+                 pred_layer_cfg: Optional[dict] = None,
+                 objectness_loss: Optional[dict] = None,
+                 center_loss: Optional[dict] = None,
+                 dir_class_loss: Optional[dict] = None,
+                 dir_res_loss: Optional[dict] = None,
+                 size_res_loss: Optional[dict] = None,
+                 corner_loss: Optional[dict] = None,
+                 vote_loss: Optional[dict] = None,
+                 init_cfg: Optional[dict] = None) -> None:
         super(SSD3DHead, self).__init__(
             num_classes,
             bbox_coder,
@@ -68,8 +68,6 @@ class SSD3DHead(VoteHead):
             vote_module_cfg=vote_module_cfg,
             vote_aggregation_cfg=vote_aggregation_cfg,
             pred_layer_cfg=pred_layer_cfg,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg,
             objectness_loss=objectness_loss,
             center_loss=center_loss,
             dir_class_loss=dir_class_loss,
@@ -78,24 +76,23 @@ class SSD3DHead(VoteHead):
             size_res_loss=size_res_loss,
             semantic_loss=None,
             init_cfg=init_cfg)
         self.corner_loss = build_loss(corner_loss)
         self.vote_loss = build_loss(vote_loss)
         self.num_candidates = vote_module_cfg['num_points']

-    def _get_cls_out_channels(self):
+    def _get_cls_out_channels(self) -> int:
         """Return the channel number of classification outputs."""
         # Class numbers (k) + objectness (1)
         return self.num_classes

-    def _get_reg_out_channels(self):
+    def _get_reg_out_channels(self) -> int:
         """Return the channel number of regression outputs."""
         # Bbox classification and regression
         # (center residual (3), size regression (3)
         # heading class+residual (num_dir_bins*2)),
         return 3 + 3 + self.num_dir_bins * 2

-    def _extract_input(self, feat_dict):
+    def _extract_input(self, feat_dict: dict) -> Tuple:
         """Extract inputs from features dictionary.

         Args:
@@ -112,86 +109,87 @@ class SSD3DHead(VoteHead):
         return seed_points, seed_features, seed_indices

-    @force_fp32(apply_to=('bbox_preds', ))
-    def loss(self,
-             bbox_preds,
-             points,
-             gt_bboxes_3d,
-             gt_labels_3d,
-             pts_semantic_mask=None,
-             pts_instance_mask=None,
-             img_metas=None,
-             gt_bboxes_ignore=None):
+    def loss_by_feat(
+            self,
+            points: List[torch.Tensor],
+            bbox_preds_dict: dict,
+            batch_gt_instances_3d: List[InstanceData],
+            batch_pts_semantic_mask: Optional[List[torch.Tensor]] = None,
+            batch_pts_instance_mask: Optional[List[torch.Tensor]] = None,
+            batch_input_metas: List[dict] = None,
+            ret_target: bool = False,
+            **kwargs) -> dict:
         """Compute loss.

         Args:
-            bbox_preds (dict): Predictions from forward of SSD3DHead.
             points (list[torch.Tensor]): Input points.
-            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
-                bboxes of each sample.
-            gt_labels_3d (list[torch.Tensor]): Labels of each sample.
-            pts_semantic_mask (list[torch.Tensor]): Point-wise
-                semantic mask.
-            pts_instance_mask (list[torch.Tensor]): Point-wise
-                instance mask.
-            img_metas (list[dict]): Contain pcd and img's meta info.
-            gt_bboxes_ignore (list[torch.Tensor]): Specify
-                which bounding.
+            bbox_preds_dict (dict): Predictions from forward of vote head.
+            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
+                gt_instances. It usually includes ``bboxes_3d`` and
+                ``labels_3d`` attributes.
+            batch_pts_semantic_mask (list[tensor]): Semantic mask of
+                point clouds. Defaults to None.
+            batch_pts_instance_mask (list[tensor]): Instance mask of
+                point clouds. Defaults to None.
+            batch_input_metas (list[dict]): Contain pcd and img's meta info.
+            ret_target (bool): Return targets or not. Defaults to False.

         Returns:
             dict: Losses of 3DSSD.
         """
-        targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d,
-                                   pts_semantic_mask, pts_instance_mask,
-                                   bbox_preds)
+        targets = self.get_targets(points, bbox_preds_dict,
+                                   batch_gt_instances_3d,
+                                   batch_pts_semantic_mask,
+                                   batch_pts_instance_mask)
         (vote_targets, center_targets, size_res_targets, dir_class_targets,
          dir_res_targets, mask_targets, centerness_targets, corner3d_targets,
          vote_mask, positive_mask, negative_mask, centerness_weights,
          box_loss_weights, heading_res_loss_weight) = targets

         # calculate centerness loss
-        centerness_loss = self.objectness_loss(
-            bbox_preds['obj_scores'].transpose(2, 1),
+        centerness_loss = self.loss_objectness(
+            bbox_preds_dict['obj_scores'].transpose(2, 1),
             centerness_targets,
             weight=centerness_weights)

         # calculate center loss
-        center_loss = self.center_loss(
-            bbox_preds['center_offset'],
+        center_loss = self.loss_center(
+            bbox_preds_dict['center_offset'],
             center_targets,
             weight=box_loss_weights.unsqueeze(-1))

         # calculate direction class loss
-        dir_class_loss = self.dir_class_loss(
-            bbox_preds['dir_class'].transpose(1, 2),
+        dir_class_loss = self.loss_dir_class(
+            bbox_preds_dict['dir_class'].transpose(1, 2),
             dir_class_targets,
             weight=box_loss_weights)

         # calculate direction residual loss
-        dir_res_loss = self.dir_res_loss(
-            bbox_preds['dir_res_norm'],
+        dir_res_loss = self.loss_dir_res(
+            bbox_preds_dict['dir_res_norm'],
             dir_res_targets.unsqueeze(-1).repeat(1, 1, self.num_dir_bins),
             weight=heading_res_loss_weight)

         # calculate size residual loss
-        size_loss = self.size_res_loss(
-            bbox_preds['size'],
+        size_loss = self.loss_size_res(
+            bbox_preds_dict['size'],
             size_res_targets,
             weight=box_loss_weights.unsqueeze(-1))

         # calculate corner loss
         one_hot_dir_class_targets = dir_class_targets.new_zeros(
-            bbox_preds['dir_class'].shape)
+            bbox_preds_dict['dir_class'].shape)
         one_hot_dir_class_targets.scatter_(2, dir_class_targets.unsqueeze(-1),
                                            1)
         pred_bbox3d = self.bbox_coder.decode(
             dict(
-                center=bbox_preds['center'],
-                dir_res=bbox_preds['dir_res'],
+                center=bbox_preds_dict['center'],
+                dir_res=bbox_preds_dict['dir_res'],
                 dir_class=one_hot_dir_class_targets,
-                size=bbox_preds['size']))
+                size=bbox_preds_dict['size']))
         pred_bbox3d = pred_bbox3d.reshape(-1, pred_bbox3d.shape[-1])
-        pred_bbox3d = img_metas[0]['box_type_3d'](
+        pred_bbox3d = batch_input_metas[0]['box_type_3d'](
             pred_bbox3d.clone(),
             box_dim=pred_bbox3d.shape[-1],
             with_yaw=self.bbox_coder.with_rot,
@@ -204,7 +202,7 @@ class SSD3DHead(VoteHead):

         # calculate vote loss
         vote_loss = self.vote_loss(
-            bbox_preds['vote_offset'].transpose(1, 2),
+            bbox_preds_dict['vote_offset'].transpose(1, 2),
             vote_targets,
             weight=vote_mask.unsqueeze(-1))
@@ -219,57 +217,74 @@ class SSD3DHead(VoteHead):

         return losses

-    def get_targets(self,
-                    points,
-                    gt_bboxes_3d,
-                    gt_labels_3d,
-                    pts_semantic_mask=None,
-                    pts_instance_mask=None,
-                    bbox_preds=None):
-        """Generate targets of ssd3d head.
+    def get_targets(
+        self,
+        points: List[Tensor],
+        bbox_preds_dict: dict = None,
+        batch_gt_instances_3d: List[InstanceData] = None,
+        batch_pts_semantic_mask: List[torch.Tensor] = None,
+        batch_pts_instance_mask: List[torch.Tensor] = None,
+    ) -> Tuple[Tensor]:
+        """Generate targets of 3DSSD head.

         Args:
             points (list[torch.Tensor]): Points of each batch.
-            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
-                bboxes of each batch.
-            gt_labels_3d (list[torch.Tensor]): Labels of each batch.
-            pts_semantic_mask (list[torch.Tensor]): Point-wise semantic
-                label of each batch.
-            pts_instance_mask (list[torch.Tensor]): Point-wise instance
-                label of each batch.
-            bbox_preds (torch.Tensor): Bounding box predictions of ssd3d head.
+            bbox_preds_dict (dict): Bounding box predictions of
+                vote head. Defaults to None.
+            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
+                gt_instances. It usually includes ``bboxes`` and ``labels``
+                attributes. Defaults to None.
+            batch_pts_semantic_mask (list[tensor]): Semantic gt mask for
+                point clouds. Defaults to None.
+            batch_pts_instance_mask (list[tensor]): Instance gt mask for
+                point clouds. Defaults to None.

         Returns:
-            tuple[torch.Tensor]: Targets of ssd3d head.
+            tuple[torch.Tensor]: Targets of 3DSSD head.
         """
+        batch_gt_labels_3d = [
+            gt_instances_3d.labels_3d
+            for gt_instances_3d in batch_gt_instances_3d
+        ]
+        batch_gt_bboxes_3d = [
+            gt_instances_3d.bboxes_3d
+            for gt_instances_3d in batch_gt_instances_3d
+        ]
         # find empty example
-        for index in range(len(gt_labels_3d)):
-            if len(gt_labels_3d[index]) == 0:
-                fake_box = gt_bboxes_3d[index].tensor.new_zeros(
-                    1, gt_bboxes_3d[index].tensor.shape[-1])
-                gt_bboxes_3d[index] = gt_bboxes_3d[index].new_box(fake_box)
-                gt_labels_3d[index] = gt_labels_3d[index].new_zeros(1)
-
-        if pts_semantic_mask is None:
-            pts_semantic_mask = [None for i in range(len(gt_labels_3d))]
-            pts_instance_mask = [None for i in range(len(gt_labels_3d))]
+        for index in range(len(batch_gt_labels_3d)):
+            if len(batch_gt_labels_3d[index]) == 0:
+                fake_box = batch_gt_bboxes_3d[index].tensor.new_zeros(
+                    1, batch_gt_bboxes_3d[index].tensor.shape[-1])
+                batch_gt_bboxes_3d[index] = batch_gt_bboxes_3d[index].new_box(
+                    fake_box)
+                batch_gt_labels_3d[index] = batch_gt_labels_3d[
+                    index].new_zeros(1)
+
+        if batch_pts_semantic_mask is None:
+            batch_pts_semantic_mask = [
+                None for _ in range(len(batch_gt_labels_3d))
+            ]
+            batch_pts_instance_mask = [
+                None for _ in range(len(batch_gt_labels_3d))
+            ]

         aggregated_points = [
-            bbox_preds['aggregated_points'][i]
-            for i in range(len(gt_labels_3d))
+            bbox_preds_dict['aggregated_points'][i]
+            for i in range(len(batch_gt_labels_3d))
         ]

         seed_points = [
-            bbox_preds['seed_points'][i, :self.num_candidates].detach()
-            for i in range(len(gt_labels_3d))
+            bbox_preds_dict['seed_points'][i, :self.num_candidates].detach()
+            for i in range(len(batch_gt_labels_3d))
         ]

         (vote_targets, center_targets, size_res_targets, dir_class_targets,
          dir_res_targets, mask_targets, centerness_targets, corner3d_targets,
          vote_mask, positive_mask, negative_mask) = multi_apply(
-            self.get_targets_single, points, gt_bboxes_3d, gt_labels_3d,
-            pts_semantic_mask, pts_instance_mask, aggregated_points,
-            seed_points)
+             self.get_targets_single, points, batch_gt_bboxes_3d,
+             batch_gt_labels_3d, batch_pts_semantic_mask,
+             batch_pts_instance_mask, aggregated_points, seed_points)

         center_targets = torch.stack(center_targets)
         positive_mask = torch.stack(positive_mask)
@@ -283,7 +298,7 @@ class SSD3DHead(VoteHead):
         vote_targets = torch.stack(vote_targets)
         vote_mask = torch.stack(vote_mask)

-        center_targets -= bbox_preds['aggregated_points']
+        center_targets -= bbox_preds_dict['aggregated_points']

         centerness_weights = (positive_mask +
                               negative_mask).unsqueeze(-1).repeat(
@@ -308,13 +323,14 @@ class SSD3DHead(VoteHead):
                heading_res_loss_weight)

     def get_targets_single(self,
-                           points,
-                           gt_bboxes_3d,
-                           gt_labels_3d,
-                           pts_semantic_mask=None,
-                           pts_instance_mask=None,
-                           aggregated_points=None,
-                           seed_points=None):
+                           points: Tensor,
+                           gt_bboxes_3d: BaseInstance3DBoxes,
+                           gt_labels_3d: Tensor,
+                           pts_semantic_mask: Optional[Tensor] = None,
+                           pts_instance_mask: Optional[Tensor] = None,
+                           aggregated_points: Optional[Tensor] = None,
+                           seed_points: Optional[Tensor] = None,
+                           **kwargs):
         """Generate targets of ssd3d head for single batch.

         Args:
@@ -440,41 +456,50 @@ class SSD3DHead(VoteHead):
             centerness_targets, corner3d_targets, vote_mask, positive_mask,
             negative_mask)

-    def get_bboxes(self, points, bbox_preds, input_metas, rescale=False):
-        """Generate bboxes from 3DSSD head predictions.
+    def predict_by_feat(self, points: List[torch.Tensor],
+                        bbox_preds_dict: dict, batch_input_metas: List[dict],
+                        **kwargs) -> List[InstanceData]:
+        """Generate bboxes from vote head predictions.

         Args:
-            points (torch.Tensor): Input points.
-            bbox_preds (dict): Predictions from ssd3d head.
-            input_metas (list[dict]): Point cloud and image's meta info.
-            rescale (bool): Whether to rescale bboxes.
+            points (List[torch.Tensor]): Input points of multiple samples.
+            bbox_preds_dict (dict): Predictions from vote head.
+            batch_input_metas (list[dict]): Each item contains
+                the meta information of each sample.

         Returns:
-            list[tuple[torch.Tensor]]: Bounding boxes, scores and labels.
+            list[:obj:`InstanceData`]: List of processed predictions. Each
+                InstanceData contains 3D bounding boxes and corresponding
+                scores and labels.
         """
         # decode boxes
-        sem_scores = F.sigmoid(bbox_preds['obj_scores']).transpose(1, 2)
+        sem_scores = F.sigmoid(bbox_preds_dict['obj_scores']).transpose(1, 2)
         obj_scores = sem_scores.max(-1)[0]
-        bbox3d = self.bbox_coder.decode(bbox_preds)
+        bbox3d = self.bbox_coder.decode(bbox_preds_dict)
         batch_size = bbox3d.shape[0]
-        results = list()
+        points = torch.stack(points)
+        results_list = []
         for b in range(batch_size):
+            temp_results = InstanceData()
             bbox_selected, score_selected, labels = self.multiclass_nms_single(
                 obj_scores[b], sem_scores[b], bbox3d[b], points[b, ..., :3],
-                input_metas[b])
-            bbox = input_metas[b]['box_type_3d'](
+                batch_input_metas[b])
+            bbox = batch_input_metas[b]['box_type_3d'](
                 bbox_selected.clone(),
                 box_dim=bbox_selected.shape[-1],
                 with_yaw=self.bbox_coder.with_rot)
-            results.append((bbox, score_selected, labels))
-        return results

-    def multiclass_nms_single(self, obj_scores, sem_scores, bbox, points,
-                              input_meta):
+            temp_results.bboxes_3d = bbox
+            temp_results.scores_3d = score_selected
+            temp_results.labels_3d = labels
+            results_list.append(temp_results)
+        return results_list
+
+    def multiclass_nms_single(self, obj_scores: Tensor, sem_scores: Tensor,
+                              bbox: Tensor, points: Tensor,
+                              input_meta: dict) -> Tuple[Tensor]:
         """Multi-class nms in single batch.

         Args:
@@ -538,7 +563,8 @@ class SSD3DHead(VoteHead):

         return bbox_selected, score_selected, labels

-    def _assign_targets_by_points_inside(self, bboxes_3d, points):
+    def _assign_targets_by_points_inside(self, bboxes_3d: BaseInstance3DBoxes,
+                                         points: Tensor) -> Tuple:
         """Compute assignment by checking whether point is inside bbox.

         Args:
......
@@ -16,11 +16,11 @@ class SSD3DNet(VoteNet):
                  train_cfg=None,
                  test_cfg=None,
                  init_cfg=None,
-                 pretrained=None):
+                 **kwargs):
         super(SSD3DNet, self).__init__(
             backbone=backbone,
             bbox_head=bbox_head,
             train_cfg=train_cfg,
             test_cfg=test_cfg,
             init_cfg=init_cfg,
-            pretrained=pretrained)
+            **kwargs)
import unittest

import torch
from mmengine import DefaultScope

from mmdet3d.registry import MODELS
from tests.utils.model_utils import (_create_detector_inputs,
                                     _get_detector_cfg, _setup_seed)


class Test3DSSD(unittest.TestCase):

    def test_3dssd(self):
        import mmdet3d.models

        assert hasattr(mmdet3d.models, 'SSD3DNet')
        DefaultScope.get_instance('test_ssd3d', scope_name='mmdet3d')
        _setup_seed(0)
        voxel_net_cfg = _get_detector_cfg('3dssd/3dssd_4x4_kitti-3d-car.py')
        model = MODELS.build(voxel_net_cfg)

        num_gt_instance = 3
        data = [
            _create_detector_inputs(
                num_gt_instance=num_gt_instance, num_classes=1)
        ]

        if torch.cuda.is_available():
            model = model.cuda()
            # test simple_test
            with torch.no_grad():
                batch_inputs, data_samples = model.data_preprocessor(
                    data, True)
                results = model.forward(
                    batch_inputs, data_samples, mode='predict')
            self.assertEqual(len(results), len(data))
            self.assertIn('bboxes_3d', results[0].pred_instances_3d)
            self.assertIn('scores_3d', results[0].pred_instances_3d)
            self.assertIn('labels_3d', results[0].pred_instances_3d)

            losses = model.forward(batch_inputs, data_samples, mode='loss')

            self.assertGreater(losses['centerness_loss'], 0)