Commit 522cc20d authored by VVsssssk, committed by ChaimZhu

[Refactor] Refactor ShapeAwareHead and FreeAnchor3DHead

parent 3c57cc41
@@ -32,7 +32,7 @@ model = dict(
         layer_strides=[2, 2, 2],
         out_channels=[64, 128, 256]),
     pts_neck=dict(
-        type='FPN',
+        type='mmdet.FPN',
         norm_cfg=dict(type='naiveSyncBN2d', eps=1e-3, momentum=0.01),
         act_cfg=dict(type='ReLU'),
         in_channels=[64, 128, 256],
...
@@ -47,6 +47,6 @@ param_scheduler = [
 ]
 # runtime settings
-train_cfg = dict(by_epoch=True, max_epochs=20, val_interval=1)
+train_cfg = dict(by_epoch=True, max_epochs=20, val_interval=20)
 val_cfg = dict()
 test_cfg = dict()
@@ -8,7 +8,7 @@ optim_wrapper = dict(
     clip_grad=dict(max_norm=35, norm_type=2))
 # training schedule for 2x
-train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=24, val_interval=1)
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=24, val_interval=24)
 val_cfg = dict(type='ValLoop')
 test_cfg = dict(type='TestLoop')
...
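Both schedule hunks above apply the same pattern: `val_interval` is raised to `max_epochs` so validation runs once, after the final epoch, instead of after every epoch. A minimal sketch of the assumed interval semantics (the modulo rule is an assumption about `EpochBasedTrainLoop`, not code from this commit):

```python
# Hedged sketch: assuming EpochBasedTrainLoop validates whenever
# epoch % val_interval == 0, raising val_interval to max_epochs
# defers validation to the very end of training.
max_epochs, val_interval = 24, 24
val_epochs = [e for e in range(1, max_epochs + 1) if e % val_interval == 0]
print(val_epochs)  # [24] -> a single validation pass after the last epoch
```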
@@ -34,14 +34,16 @@ model = dict(
         dir_offset=-0.7854,  # -pi / 4
         bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder', code_size=9),
         loss_cls=dict(
-            type='FocalLoss',
+            type='mmdet.FocalLoss',
             use_sigmoid=True,
             gamma=2.0,
             alpha=0.25,
             loss_weight=1.0),
-        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.8),
+        loss_bbox=dict(
+            type='mmdet.SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.8),
         loss_dir=dict(
-            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.2)),
+            type='mmdet.CrossEntropyLoss', use_sigmoid=False,
+            loss_weight=0.2)),
     # model training and testing settings
     train_cfg=dict(
         pts=dict(code_weight=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.25, 0.25])))
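The `mmdet.` prefixes above are registry scopes: after the refactor, losses implemented in mmdet are resolved across registries instead of being re-registered in mmdet3d. A hedged sketch of how such a scoped type is built (assuming mmengine's cross-scope registry lookup):

```python
# Hedged sketch: build a loss whose type carries a registry scope.
# 'mmdet.FocalLoss' asks mmdet3d's MODELS registry to resolve FocalLoss
# in the mmdet scope rather than in mmdet3d's own scope.
from mmdet3d.registry import MODELS

focal_loss = MODELS.build(
    dict(
        type='mmdet.FocalLoss',  # scope-prefixed type
        use_sigmoid=True,
        gamma=2.0,
        alpha=0.25,
        loss_weight=1.0))
```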
@@ -60,11 +60,25 @@ train_pipeline = [
     dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
     dict(type='ObjectNameFilter', classes=class_names),
     dict(type='PointShuffle'),
-    dict(type='DefaultFormatBundle3D', class_names=class_names),
-    dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
+    dict(
+        type='Pack3DDetInputs',
+        keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
 ]
-data = dict(train=dict(pipeline=train_pipeline))
-lr_config = dict(step=[28, 34])
-runner = dict(max_epochs=36)
-evaluation = dict(interval=36)
+train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
+train_cfg = dict(max_epochs=36, val_interval=36)
+param_scheduler = [
+    dict(
+        type='LinearLR',
+        start_factor=1.0 / 1000,
+        by_epoch=False,
+        begin=0,
+        end=1000),
+    dict(
+        type='MultiStepLR',
+        begin=0,
+        end=36,
+        by_epoch=True,
+        milestones=[28, 34],
+        gamma=0.1)
+]
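The removed `lr_config`/`runner`/`evaluation` trio becomes `train_cfg` plus an explicit `param_scheduler` list: `LinearLR` expresses the warmup stage and `MultiStepLR` the old `step=[28, 34]` decay. A quick sanity check of the decayed learning rates under the assumed `MultiStepLR` semantics (`base_lr` here is illustrative):

```python
# Hedged sketch: the MultiStepLR entry above multiplies the LR by gamma
# at each milestone, matching the old lr_config = dict(step=[28, 34]).
def lr_at_epoch(epoch, base_lr=1e-3, milestones=(28, 34), gamma=0.1):
    lr = base_lr
    for m in milestones:
        if epoch >= m:
            lr *= gamma
    return lr

print(lr_at_epoch(27), lr_at_epoch(30), lr_at_epoch(35))
# ~1e-3, ~1e-4, ~1e-5 -> decay by 10x after epochs 28 and 34
```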
@@ -60,11 +60,25 @@ train_pipeline = [
     dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
     dict(type='ObjectNameFilter', classes=class_names),
     dict(type='PointShuffle'),
-    dict(type='DefaultFormatBundle3D', class_names=class_names),
-    dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
+    dict(
+        type='Pack3DDetInputs',
+        keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
 ]
-data = dict(train=dict(pipeline=train_pipeline))
-lr_config = dict(step=[28, 34])
-runner = dict(max_epochs=36)
-evaluation = dict(interval=36)
+train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
+train_cfg = dict(max_epochs=36, val_interval=36)
+# learning rate
+param_scheduler = [
+    dict(
+        type='LinearLR',
+        start_factor=1.0 / 1000,
+        by_epoch=False,
+        begin=0,
+        end=1000),
+    dict(
+        type='MultiStepLR',
+        begin=0,
+        end=36,
+        by_epoch=True,
+        milestones=[28, 34],
+        gamma=0.1)
+]
@@ -41,9 +41,3 @@ model = dict(
         ],
         rotations=[0, 1.57],
         reshape_out=True)))
-# For Lyft dataset, we usually evaluate the model at the end of training.
-# Since the models are trained by 24 epochs by default, we set evaluation
-# interval to be 24. Please change the interval accordingly if you do not
-# use a default schedule.
-train_cfg = dict(val_interval=24)
@@ -29,8 +29,9 @@ train_pipeline = [
     dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
     dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
     dict(type='PointShuffle'),
-    dict(type='DefaultFormatBundle3D', class_names=class_names),
-    dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
+    dict(
+        type='Pack3DDetInputs',
+        keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
 ]
 test_pipeline = [
     dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=5, use_dim=5),
@@ -48,20 +49,18 @@ test_pipeline = [
             translation_std=[0, 0, 0]),
         dict(type='RandomFlip3D'),
         dict(
-            type='PointsRangeFilter', point_cloud_range=point_cloud_range),
-        dict(
-            type='DefaultFormatBundle3D',
-            class_names=class_names,
-            with_label=False),
-        dict(type='Collect3D', keys=['points'])
-    ])
+            type='PointsRangeFilter', point_cloud_range=point_cloud_range)
+    ]),
+    dict(type='Pack3DDetInputs', keys=['points'])
 ]
-data = dict(
-    samples_per_gpu=2,
-    workers_per_gpu=4,
-    train=dict(pipeline=train_pipeline, classes=class_names),
-    val=dict(pipeline=test_pipeline, classes=class_names),
-    test=dict(pipeline=test_pipeline, classes=class_names))
+train_dataloader = dict(
+    batch_size=2,
+    num_workers=4,
+    dataset=dict(pipeline=train_pipeline, metainfo=dict(CLASSES=class_names)))
+test_dataloader = dict(
+    dataset=dict(pipeline=test_pipeline, metainfo=dict(CLASSES=class_names)))
+val_dataloader = dict(
+    dataset=dict(pipeline=test_pipeline, metainfo=dict(CLASSES=class_names)))

 # model settings
 model = dict(
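The monolithic `data = dict(...)` block is split into per-split dataloaders with standard PyTorch naming. A hedged summary of the key mapping used above (pipeline and class names are stubs so the snippet stands alone):

```python
# Hedged sketch of the 0.x -> 1.x dataloader key mapping:
#   samples_per_gpu -> batch_size, workers_per_gpu -> num_workers,
#   train/val/test=dict(...) -> *_dataloader=dict(dataset=dict(...)),
#   classes=... -> metainfo=dict(CLASSES=...)
class_names = ['car', 'pedestrian']  # stub
train_pipeline = []                  # stub

train_dataloader = dict(
    batch_size=2,    # was samples_per_gpu
    num_workers=4,   # was workers_per_gpu
    dataset=dict(
        pipeline=train_pipeline,
        metainfo=dict(CLASSES=class_names)))  # was classes=class_names
```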
@@ -148,84 +147,86 @@ model = dict(
         dir_limit_offset=0,
         bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder', code_size=9),
         loss_cls=dict(
-            type='FocalLoss',
+            type='mmdet.FocalLoss',
             use_sigmoid=True,
             gamma=2.0,
             alpha=0.25,
             loss_weight=1.0),
-        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
+        loss_bbox=dict(
+            type='mmdet.SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
         loss_dir=dict(
-            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.2)),
+            type='mmdet.CrossEntropyLoss', use_sigmoid=False,
+            loss_weight=0.2)),
     # model training and testing settings
     train_cfg=dict(
         _delete_=True,
         pts=dict(
             assigner=[
                 dict(  # bicycle
-                    type='MaxIoUAssigner',
+                    type='Max3DIoUAssigner',
                     iou_calculator=dict(type='BboxOverlapsNearest3D'),
                     pos_iou_thr=0.5,
                     neg_iou_thr=0.35,
                     min_pos_iou=0.35,
                     ignore_iof_thr=-1),
                 dict(  # motorcycle
-                    type='MaxIoUAssigner',
+                    type='Max3DIoUAssigner',
                     iou_calculator=dict(type='BboxOverlapsNearest3D'),
                     pos_iou_thr=0.5,
                     neg_iou_thr=0.3,
                     min_pos_iou=0.3,
                     ignore_iof_thr=-1),
                 dict(  # pedestrian
-                    type='MaxIoUAssigner',
+                    type='Max3DIoUAssigner',
                     iou_calculator=dict(type='BboxOverlapsNearest3D'),
                     pos_iou_thr=0.6,
                     neg_iou_thr=0.4,
                     min_pos_iou=0.4,
                     ignore_iof_thr=-1),
                 dict(  # traffic cone
-                    type='MaxIoUAssigner',
+                    type='Max3DIoUAssigner',
                     iou_calculator=dict(type='BboxOverlapsNearest3D'),
                     pos_iou_thr=0.6,
                     neg_iou_thr=0.4,
                     min_pos_iou=0.4,
                     ignore_iof_thr=-1),
                 dict(  # barrier
-                    type='MaxIoUAssigner',
+                    type='Max3DIoUAssigner',
                     iou_calculator=dict(type='BboxOverlapsNearest3D'),
                     pos_iou_thr=0.55,
                     neg_iou_thr=0.4,
                     min_pos_iou=0.4,
                     ignore_iof_thr=-1),
                 dict(  # car
-                    type='MaxIoUAssigner',
+                    type='Max3DIoUAssigner',
                     iou_calculator=dict(type='BboxOverlapsNearest3D'),
                     pos_iou_thr=0.6,
                     neg_iou_thr=0.45,
                     min_pos_iou=0.45,
                     ignore_iof_thr=-1),
                 dict(  # truck
-                    type='MaxIoUAssigner',
+                    type='Max3DIoUAssigner',
                     iou_calculator=dict(type='BboxOverlapsNearest3D'),
                     pos_iou_thr=0.55,
                     neg_iou_thr=0.4,
                     min_pos_iou=0.4,
                     ignore_iof_thr=-1),
                 dict(  # trailer
-                    type='MaxIoUAssigner',
+                    type='Max3DIoUAssigner',
                     iou_calculator=dict(type='BboxOverlapsNearest3D'),
                     pos_iou_thr=0.5,
                     neg_iou_thr=0.35,
                     min_pos_iou=0.35,
                     ignore_iof_thr=-1),
                 dict(  # bus
-                    type='MaxIoUAssigner',
+                    type='Max3DIoUAssigner',
                     iou_calculator=dict(type='BboxOverlapsNearest3D'),
                     pos_iou_thr=0.55,
                     neg_iou_thr=0.4,
                     min_pos_iou=0.4,
                     ignore_iof_thr=-1),
                 dict(  # construction vehicle
-                    type='MaxIoUAssigner',
+                    type='Max3DIoUAssigner',
                     iou_calculator=dict(type='BboxOverlapsNearest3D'),
                     pos_iou_thr=0.5,
                     neg_iou_thr=0.35,
...
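All ten per-class assigners switch from `MaxIoUAssigner` to the renamed `Max3DIoUAssigner` while keeping their thresholds. A hedged sketch of building one of them standalone (the `TASK_UTILS` registry hosting assigners in 1.x is an assumption; thresholds mirror the 'car' entry above):

```python
# Hedged sketch: construct the renamed assigner directly for inspection.
from mmdet3d.registry import TASK_UTILS

car_assigner = TASK_UTILS.build(
    dict(
        type='Max3DIoUAssigner',
        iou_calculator=dict(type='BboxOverlapsNearest3D'),
        pos_iou_thr=0.6,
        neg_iou_thr=0.45,
        min_pos_iou=0.45,
        ignore_iof_thr=-1))
```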
@@ -264,8 +264,8 @@ class Base3DDenseHead(BaseModule, metaclass=ABCMeta):
                                 cfg: ConfigDict,
                                 rescale: bool = False,
                                 **kwargs) -> InstanceData:
-        """Transform a single image's features extracted from the head into
-        bbox results.
+        """Transform a single points sample's features extracted from the head
+        into bbox results.

         Args:
             cls_score_list (list[Tensor]): Box scores from all scale
...
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List
+
 import torch
-from mmcv.runner import force_fp32
+from torch import Tensor
 from torch.nn import functional as F

 from mmdet3d.core.bbox import bbox_overlaps_nearest_3d
+from mmdet3d.core.utils import InstanceList, OptInstanceList
 from mmdet3d.registry import MODELS
 from .anchor3d_head import Anchor3DHead
 from .train_mixins import get_direction_target
@@ -29,27 +32,26 @@ class FreeAnchor3DHead(Anchor3DHead):
     """  # noqa: E501

     def __init__(self,
-                 pre_anchor_topk=50,
-                 bbox_thr=0.6,
-                 gamma=2.0,
-                 alpha=0.5,
-                 init_cfg=None,
-                 **kwargs):
+                 pre_anchor_topk: int = 50,
+                 bbox_thr: float = 0.6,
+                 gamma: float = 2.0,
+                 alpha: float = 0.5,
+                 init_cfg: dict = None,
+                 **kwargs) -> None:
         super().__init__(init_cfg=init_cfg, **kwargs)
         self.pre_anchor_topk = pre_anchor_topk
         self.bbox_thr = bbox_thr
         self.gamma = gamma
         self.alpha = alpha

-    @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'dir_cls_preds'))
-    def loss(self,
-             cls_scores,
-             bbox_preds,
-             dir_cls_preds,
-             gt_bboxes,
-             gt_labels,
-             input_metas,
-             gt_bboxes_ignore=None):
+    def loss_by_feat(
+            self,
+            cls_scores: List[Tensor],
+            bbox_preds: List[Tensor],
+            dir_cls_preds: List[Tensor],
+            batch_gt_instances_3d: InstanceList,
+            batch_input_metas: List[dict],
+            batch_gt_instances_ignore: OptInstanceList = None) -> Dict:
         """Calculate loss of FreeAnchor head.

         Args:
@@ -59,11 +61,14 @@ class FreeAnchor3DHead(Anchor3DHead):
                 different samples
             dir_cls_preds (list[torch.Tensor]): Direction predictions of
                 different samples
-            gt_bboxes (list[:obj:`BaseInstance3DBoxes`]): Ground truth boxes.
-            gt_labels (list[torch.Tensor]): Ground truth labels.
-            input_metas (list[dict]): List of input meta information.
-            gt_bboxes_ignore (list[:obj:`BaseInstance3DBoxes`], optional):
-                Ground truth boxes that should be ignored. Defaults to None.
+            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
+                gt_instances. It usually includes ``bboxes_3d`` and
+                ``labels_3d`` attributes.
+            batch_input_metas (list[dict]): Contain pcd and img's meta info.
+            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
+                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
+                data that is ignored during training and testing.
+                Defaults to None.

         Returns:
             dict[str, torch.Tensor]: Loss items.
@@ -72,10 +77,10 @@ class FreeAnchor3DHead(Anchor3DHead):
                 - negative_bag_loss (torch.Tensor): Loss of negative samples.
         """
         featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
-        assert len(featmap_sizes) == self.anchor_generator.num_levels
+        assert len(featmap_sizes) == self.prior_generator.num_levels

-        anchor_list = self.get_anchors(featmap_sizes, input_metas)
-        anchors = [torch.cat(anchor) for anchor in anchor_list]
+        anchor_list = self.get_anchors(featmap_sizes, batch_input_metas)
+        mlvl_anchors = [torch.cat(anchor) for anchor in anchor_list]

         # concatenate each level
         cls_scores = [
@@ -98,24 +103,24 @@ class FreeAnchor3DHead(Anchor3DHead):
         bbox_preds = torch.cat(bbox_preds, dim=1)
         dir_cls_preds = torch.cat(dir_cls_preds, dim=1)

-        cls_prob = torch.sigmoid(cls_scores)
+        cls_probs = torch.sigmoid(cls_scores)
         box_prob = []
         num_pos = 0
         positive_losses = []
-        for _, (anchors_, gt_labels_, gt_bboxes_, cls_prob_, bbox_preds_,
-                dir_cls_preds_) in enumerate(
-                    zip(anchors, gt_labels, gt_bboxes, cls_prob, bbox_preds,
-                        dir_cls_preds)):
-
-            gt_bboxes_ = gt_bboxes_.tensor.to(anchors_.device)
+        for _, (anchors, gt_instance_3d, cls_prob, bbox_pred,
+                dir_cls_pred) in enumerate(
+                    zip(mlvl_anchors, batch_gt_instances_3d, cls_probs,
+                        bbox_preds, dir_cls_preds)):
+            gt_bboxes = gt_instance_3d.bboxes_3d.tensor.to(anchors.device)
+            gt_labels = gt_instance_3d.labels_3d.to(anchors.device)

             with torch.no_grad():
                 # box_localization: a_{j}^{loc}, shape: [j, 4]
-                pred_boxes = self.bbox_coder.decode(anchors_, bbox_preds_)
+                pred_boxes = self.bbox_coder.decode(anchors, bbox_pred)

                 # object_box_iou: IoU_{ij}^{loc}, shape: [i, j]
                 object_box_iou = bbox_overlaps_nearest_3d(
-                    gt_bboxes_, pred_boxes)
+                    gt_bboxes, pred_boxes)

                 # object_box_prob: P{a_{j} -> b_{i}}, shape: [i, j]
                 t1 = self.bbox_thr
@@ -125,9 +130,9 @@ class FreeAnchor3DHead(Anchor3DHead):
                     min=0, max=1)

                 # object_cls_box_prob: P{a_{j} -> b_{i}}, shape: [i, c, j]
-                num_obj = gt_labels_.size(0)
+                num_obj = gt_labels.size(0)
                 indices = torch.stack(
-                    [torch.arange(num_obj).type_as(gt_labels_), gt_labels_],
+                    [torch.arange(num_obj).type_as(gt_labels), gt_labels],
                     dim=0)

                 object_cls_box_prob = torch.sparse_coo_tensor(
@@ -147,11 +152,11 @@ class FreeAnchor3DHead(Anchor3DHead):
                 indices = torch.nonzero(box_cls_prob, as_tuple=False).t_()
                 if indices.numel() == 0:
                     image_box_prob = torch.zeros(
-                        anchors_.size(0),
+                        anchors.size(0),
                         self.num_classes).type_as(object_box_prob)
                 else:
                     nonzero_box_prob = torch.where(
-                        (gt_labels_.unsqueeze(dim=-1) == indices[0]),
+                        (gt_labels.unsqueeze(dim=-1) == indices[0]),
                         object_box_prob[:, indices[1]],
                         torch.tensor(
                             [0]).type_as(object_box_prob)).max(dim=0).values
@@ -160,14 +165,13 @@ class FreeAnchor3DHead(Anchor3DHead):
                     image_box_prob = torch.sparse_coo_tensor(
                         indices.flip([0]),
                         nonzero_box_prob,
-                        size=(anchors_.size(0), self.num_classes)).to_dense()
+                        size=(anchors.size(0), self.num_classes)).to_dense()
                 # end

             box_prob.append(image_box_prob)

             # construct bags for objects
-            match_quality_matrix = bbox_overlaps_nearest_3d(
-                gt_bboxes_, anchors_)
+            match_quality_matrix = bbox_overlaps_nearest_3d(gt_bboxes, anchors)
             _, matched = torch.topk(
                 match_quality_matrix,
                 self.pre_anchor_topk,
@@ -177,15 +181,15 @@ class FreeAnchor3DHead(Anchor3DHead):

             # matched_cls_prob: P_{ij}^{cls}
             matched_cls_prob = torch.gather(
-                cls_prob_[matched], 2,
-                gt_labels_.view(-1, 1, 1).repeat(1, self.pre_anchor_topk,
-                                                 1)).squeeze(2)
+                cls_prob[matched], 2,
+                gt_labels.view(-1, 1, 1).repeat(1, self.pre_anchor_topk,
+                                                1)).squeeze(2)

             # matched_box_prob: P_{ij}^{loc}
-            matched_anchors = anchors_[matched]
+            matched_anchors = anchors[matched]
             matched_object_targets = self.bbox_coder.encode(
                 matched_anchors,
-                gt_bboxes_.unsqueeze(dim=1).expand_as(matched_anchors))
+                gt_bboxes.unsqueeze(dim=1).expand_as(matched_anchors))

             # direction classification loss
             loss_dir = None
@@ -198,15 +202,16 @@ class FreeAnchor3DHead(Anchor3DHead):
                     self.dir_limit_offset,
                     one_hot=False)
                 loss_dir = self.loss_dir(
-                    dir_cls_preds_[matched].transpose(-2, -1),
+                    dir_cls_pred[matched].transpose(-2, -1),
                     matched_dir_targets,
                     reduction_override='none')

             # generate bbox weights
             if self.diff_rad_by_sin:
-                bbox_preds_[matched], matched_object_targets = \
+                bbox_preds_clone = bbox_pred.clone()
+                bbox_preds_clone[matched], matched_object_targets = \
                     self.add_sin_difference(
-                        bbox_preds_[matched], matched_object_targets)
+                        bbox_preds_clone[matched], matched_object_targets)

             bbox_weights = matched_anchors.new_ones(matched_anchors.size())
             # Use pop is not right, check performance
             code_weight = self.train_cfg.get('code_weight', None)
@@ -214,7 +219,7 @@ class FreeAnchor3DHead(Anchor3DHead):
                 bbox_weights = bbox_weights * bbox_weights.new_tensor(
                     code_weight)
             loss_bbox = self.loss_bbox(
-                bbox_preds_[matched],
+                bbox_preds_clone[matched],
                 matched_object_targets,
                 bbox_weights,
                 reduction_override='none').sum(-1)
@@ -224,7 +229,7 @@ class FreeAnchor3DHead(Anchor3DHead):
             matched_box_prob = torch.exp(-loss_bbox)

             # positive_losses: {-log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) )}
-            num_pos += len(gt_bboxes_)
+            num_pos += len(gt_bboxes)
             positive_losses.append(
                 self.positive_bag_loss(matched_cls_prob, matched_box_prob))
@@ -244,7 +249,8 @@ class FreeAnchor3DHead(Anchor3DHead):
         }
         return losses

-    def positive_bag_loss(self, matched_cls_prob, matched_box_prob):
+    def positive_bag_loss(self, matched_cls_prob: Tensor,
+                          matched_box_prob: Tensor) -> Tensor:
         """Generate positive bag loss.

         Args:
@@ -266,7 +272,7 @@ class FreeAnchor3DHead(Anchor3DHead):
         return self.alpha * F.binary_cross_entropy(
             bag_prob, torch.ones_like(bag_prob), reduction='none')

-    def negative_bag_loss(self, cls_prob, box_prob):
+    def negative_bag_loss(self, cls_prob: Tensor, box_prob: Tensor) -> Tensor:
         """Generate negative bag loss.

         Args:
...
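For reference, the `positive_bag_loss` annotated above computes the Mean-max bag probability named in the comments (`-log(Mean-max(P_{ij}^{cls} * P_{ij}^{loc}))`). A self-contained sketch of that computation, assuming the standard FreeAnchor formulation; the standalone function and sample shapes are illustrative:

```python
import torch
import torch.nn.functional as F


def positive_bag_loss(matched_cls_prob, matched_box_prob, alpha=0.5):
    """Mean-max positive bag loss sketch (assumed FreeAnchor formulation)."""
    # P_{ij} = P_{ij}^{cls} * P_{ij}^{loc}, shape [num_gt, pre_anchor_topk]
    matched_prob = matched_cls_prob * matched_box_prob
    # Mean-max: weights concentrate on the highest-probability anchors
    weight = 1 / torch.clamp(1 - matched_prob, 1e-12, None)
    weight = weight / weight.sum(dim=-1, keepdim=True)
    bag_prob = (weight * matched_prob).sum(dim=-1)
    # -log(bag_prob), expressed as BCE against an all-ones target
    return alpha * F.binary_cross_entropy(
        bag_prob, torch.ones_like(bag_prob), reduction='none')


loss = positive_bag_loss(torch.rand(3, 50), torch.rand(3, 50))
print(loss.shape)  # torch.Size([3]) -> one bag loss per ground-truth object
```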
 # Copyright (c) OpenMMLab. All rights reserved.
 import warnings
+from typing import Dict, List, Optional, Tuple

 import numpy as np
 import torch
 from mmcv.cnn import ConvModule
-from mmcv.runner import BaseModule
+from mmengine.data import InstanceData
+from mmengine.model import BaseModule
+from torch import Tensor
 from torch import nn as nn

 from mmdet3d.core import box3d_multiclass_nms, limit_period, xywhr2xyxyr
+from mmdet3d.core.utils import InstanceList, OptInstanceList
 from mmdet3d.registry import MODELS
 from mmdet.core import multi_apply
 from ..builder import build_head
@@ -33,29 +37,30 @@ class BaseShapeHead(BaseModule):
         in_channels (int): Input channels for convolutional layers.
         shared_conv_channels (tuple, optional): Channels for shared
             convolutional layers. Default: (64, 64).
-        shared_conv_strides (tuple, optional): Strides for shared
+        shared_conv_strides (tuple): Strides for shared
             convolutional layers. Default: (1, 1).
-        use_direction_classifier (bool, optional): Whether to use direction
+        use_direction_classifier (bool): Whether to use direction
             classifier. Default: True.
-        conv_cfg (dict, optional): Config of conv layer.
+        conv_cfg (dict): Config of conv layer.
             Default: dict(type='Conv2d')
-        norm_cfg (dict, optional): Config of norm layer.
+        norm_cfg (dict): Config of norm layer.
             Default: dict(type='BN2d').
-        bias (bool | str, optional): Type of bias. Default: False.
+        bias (bool | str): Type of bias. Default: False.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
     """

     def __init__(self,
-                 num_cls,
-                 num_base_anchors,
-                 box_code_size,
-                 in_channels,
-                 shared_conv_channels=(64, 64),
-                 shared_conv_strides=(1, 1),
-                 use_direction_classifier=True,
-                 conv_cfg=dict(type='Conv2d'),
-                 norm_cfg=dict(type='BN2d'),
-                 bias=False,
-                 init_cfg=None):
+                 num_cls: int,
+                 num_base_anchors: int,
+                 box_code_size: int,
+                 in_channels: int,
+                 shared_conv_channels: Tuple = (64, 64),
+                 shared_conv_strides: Tuple = (1, 1),
+                 use_direction_classifier: bool = True,
+                 conv_cfg: Dict = dict(type='Conv2d'),
+                 norm_cfg: Dict = dict(type='BN2d'),
+                 bias: bool = False,
+                 init_cfg: Optional[dict] = None) -> None:
         super().__init__(init_cfg=init_cfg)
         self.num_cls = num_cls
         self.num_base_anchors = num_base_anchors
@@ -122,7 +127,7 @@ class BaseShapeHead(BaseModule):
                     bias_prob=0.01)
             ])

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Dict:
         """Forward function for SmallHead.

         Args:
@@ -171,13 +176,16 @@ class ShapeAwareHead(Anchor3DHead):
     Args:
         tasks (dict): Shape-aware groups of multi-class objects.
-        assign_per_class (bool, optional): Whether to do assignment for each
+        assign_per_class (bool): Whether to do assignment for each
             class. Default: True.
-        kwargs (dict): Other arguments are the same as those in
-            :class:`Anchor3DHead`.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
     """

-    def __init__(self, tasks, assign_per_class=True, init_cfg=None, **kwargs):
+    def __init__(self,
+                 tasks: Dict,
+                 assign_per_class: bool = True,
+                 init_cfg: Optional[dict] = None,
+                 **kwargs) -> None:
         self.tasks = tasks
         self.featmap_sizes = []
         super().__init__(
@@ -198,10 +206,10 @@ class ShapeAwareHead(Anchor3DHead):
         self.heads = nn.ModuleList()
         cls_ptr = 0
         for task in self.tasks:
-            sizes = self.anchor_generator.sizes[cls_ptr:cls_ptr +
-                                                task['num_class']]
+            sizes = self.prior_generator.sizes[cls_ptr:cls_ptr +
+                                               task['num_class']]
             num_size = torch.tensor(sizes).reshape(-1, 3).size(0)
-            num_rot = len(self.anchor_generator.rotations)
+            num_rot = len(self.prior_generator.rotations)
             num_base_anchors = num_rot * num_size
             branch = dict(
                 type='BaseShapeHead',
@@ -214,7 +222,7 @@ class ShapeAwareHead(Anchor3DHead):
             self.heads.append(build_head(branch))
             cls_ptr += task['num_class']

-    def forward_single(self, x):
+    def forward_single(self, x: Tensor) -> Tuple[Tensor]:
         """Forward function on a single-scale feature map.

         Args:
@@ -241,15 +249,18 @@ class ShapeAwareHead(Anchor3DHead):
         for i, task in enumerate(self.tasks):
             for _ in range(task['num_class']):
                 self.featmap_sizes.append(results[i]['featmap_size'])
-        assert len(self.featmap_sizes) == len(self.anchor_generator.ranges), \
+        assert len(self.featmap_sizes) == len(self.prior_generator.ranges), \
             'Length of feature map sizes must be equal to length of ' + \
             'different ranges of anchor generator.'

         return cls_score, bbox_pred, dir_cls_preds

-    def loss_single(self, cls_score, bbox_pred, dir_cls_preds, labels,
-                    label_weights, bbox_targets, bbox_weights, dir_targets,
-                    dir_weights, num_total_samples):
+    def loss_single(self, cls_score: Tensor, bbox_pred: Tensor,
+                    dir_cls_preds: Tensor, labels: Tensor,
+                    label_weights: Tensor, bbox_targets: Tensor,
+                    bbox_weights: Tensor, dir_targets: Tensor,
+                    dir_weights: Tensor,
+                    num_total_samples: int) -> Tuple[Tensor]:
         """Calculate loss of Single-level results.

         Args:
@@ -309,27 +320,30 @@ class ShapeAwareHead(Anchor3DHead):
         return loss_cls, loss_bbox, loss_dir

-    def loss(self,
-             cls_scores,
-             bbox_preds,
-             dir_cls_preds,
-             gt_bboxes,
-             gt_labels,
-             input_metas,
-             gt_bboxes_ignore=None):
-        """Calculate losses.
+    def loss_by_feat(
+            self,
+            cls_scores: List[Tensor],
+            bbox_preds: List[Tensor],
+            dir_cls_preds: List[Tensor],
+            batch_gt_instances_3d: InstanceList,
+            batch_input_metas: List[dict],
+            batch_gt_instances_ignore: OptInstanceList = None) -> Dict:
+        """Calculate the loss based on the features extracted by the detection
+        head.

         Args:
             cls_scores (list[torch.Tensor]): Multi-level class scores.
             bbox_preds (list[torch.Tensor]): Multi-level bbox predictions.
             dir_cls_preds (list[torch.Tensor]): Multi-level direction
                 class predictions.
-            gt_bboxes (list[:obj:`BaseInstance3DBoxes`]): Gt bboxes
-                of each sample.
-            gt_labels (list[torch.Tensor]): Gt labels of each sample.
-            input_metas (list[dict]): Contain pcd and img's meta info.
-            gt_bboxes_ignore (list[torch.Tensor]): Specify
-                which bounding.
+            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
+                gt_instances. It usually includes ``bboxes_3d`` and
+                ``labels_3d`` attributes.
+            batch_input_metas (list[dict]): Contain pcd and sample's meta info.
+            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
+                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
+                data that is ignored during training and testing.
+                Defaults to None.

         Returns:
             dict[str, list[torch.Tensor]]: Classification, bbox, and
@@ -342,13 +356,12 @@ class ShapeAwareHead(Anchor3DHead):
         """
         device = cls_scores[0].device
         anchor_list = self.get_anchors(
-            self.featmap_sizes, input_metas, device=device)
+            self.featmap_sizes, batch_input_metas, device=device)
         cls_reg_targets = self.anchor_target_3d(
             anchor_list,
-            gt_bboxes,
-            input_metas,
-            gt_bboxes_ignore_list=gt_bboxes_ignore,
-            gt_labels_list=gt_labels,
+            batch_gt_instances_3d,
+            batch_input_metas,
+            batch_gt_instances_ignore=batch_gt_instances_ignore,
             num_classes=self.num_classes,
             sampling=self.sampling)
@@ -376,21 +389,22 @@ class ShapeAwareHead(Anchor3DHead):
         return dict(
             loss_cls=losses_cls, loss_bbox=losses_bbox, loss_dir=losses_dir)

-    def get_bboxes(self,
-                   cls_scores,
-                   bbox_preds,
-                   dir_cls_preds,
-                   input_metas,
-                   cfg=None,
-                   rescale=False):
-        """Get bboxes of anchor head.
+    def predict_by_feat(self,
+                        cls_scores: List[Tensor],
+                        bbox_preds: List[Tensor],
+                        dir_cls_preds: List[Tensor],
+                        batch_input_metas: List[dict],
+                        cfg: Optional[dict] = None,
+                        rescale: List[Tensor] = False) -> List[tuple]:
+        """Transform a batch of output features extracted from the head into
+        bbox results.

         Args:
             cls_scores (list[torch.Tensor]): Multi-level class scores.
             bbox_preds (list[torch.Tensor]): Multi-level bbox predictions.
             dir_cls_preds (list[torch.Tensor]): Multi-level direction
                 class predictions.
-            input_metas (list[dict]): Contain pcd and img's meta info.
+            batch_input_metas (list[dict]): Contain pcd and img's meta info.
             cfg (:obj:`ConfigDict`, optional): Training or testing config.
                 Default: None.
             rescale (list[torch.Tensor], optional): Whether to rescale bbox.
@@ -404,13 +418,13 @@ class ShapeAwareHead(Anchor3DHead):
         num_levels = len(cls_scores)
         assert num_levels == 1, 'Only support single level inference.'
         device = cls_scores[0].device
-        mlvl_anchors = self.anchor_generator.grid_anchors(
+        mlvl_anchors = self.prior_generator.grid_anchors(
             self.featmap_sizes, device=device)
         # `anchor` is a list of anchors for different classes
         mlvl_anchors = [torch.cat(anchor, dim=0) for anchor in mlvl_anchors]

         result_list = []
-        for img_id in range(len(input_metas)):
+        for img_id in range(len(batch_input_metas)):
             cls_score_list = [
                 cls_scores[i][img_id].detach() for i in range(num_levels)
             ]
@@ -421,22 +435,25 @@ class ShapeAwareHead(Anchor3DHead):
                 dir_cls_preds[i][img_id].detach() for i in range(num_levels)
             ]

-            input_meta = input_metas[img_id]
-            proposals = self.get_bboxes_single(cls_score_list, bbox_pred_list,
-                                               dir_cls_pred_list, mlvl_anchors,
-                                               input_meta, cfg, rescale)
+            input_meta = batch_input_metas[img_id]
+            proposals = self._predict_by_feat_single(cls_score_list,
+                                                     bbox_pred_list,
+                                                     dir_cls_pred_list,
+                                                     mlvl_anchors, input_meta,
+                                                     cfg, rescale)
             result_list.append(proposals)
         return result_list

-    def get_bboxes_single(self,
-                          cls_scores,
-                          bbox_preds,
-                          dir_cls_preds,
-                          mlvl_anchors,
-                          input_meta,
-                          cfg=None,
-                          rescale=False):
-        """Get bboxes of single branch.
+    def _predict_by_feat_single(self,
+                                cls_scores: Tensor,
+                                bbox_preds: Tensor,
+                                dir_cls_preds: Tensor,
+                                mlvl_anchors: List[Tensor],
+                                input_meta: List[dict],
+                                cfg: Dict = None,
+                                rescale: List[Tensor] = False):
+        """Transform a single point's features extracted from the head into
+        bbox results.

         Args:
             cls_scores (torch.Tensor): Class score in single batch.
@@ -447,7 +464,7 @@ class ShapeAwareHead(Anchor3DHead):
                 in single batch.
             input_meta (list[dict]): Contain pcd and img's meta info.
             cfg (:obj:`ConfigDict`): Training or testing config.
-            rescale (list[torch.Tensor], optional): whether to rescale bbox.
+            rescale (list[torch.Tensor]): whether to rescale bbox.
                 Default: False.

         Returns:
@@ -513,4 +530,8 @@ class ShapeAwareHead(Anchor3DHead):
                 dir_rot + self.dir_offset +
                 np.pi * dir_scores.to(bboxes.dtype))
         bboxes = input_meta['box_type_3d'](bboxes, box_dim=self.box_code_size)
-        return bboxes, scores, labels
+        results = InstanceData()
+        results.bboxes_3d = bboxes
+        results.scores_3d = scores
+        results.labels_3d = labels
+        return results
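Returning an `InstanceData` instead of a `(bboxes, scores, labels)` tuple keeps the three fields aligned under indexing. A hedged sketch (plain tensors stand in for `LiDARInstance3DBoxes`; the `mmengine.data` path matches the import in this diff, though later mmengine releases moved it to `mmengine.structures`):

```python
import torch
from mmengine.data import InstanceData

results = InstanceData()
results.bboxes_3d = torch.rand(5, 7)  # stand-in for LiDARInstance3DBoxes
results.scores_3d = torch.rand(5)
results.labels_3d = torch.randint(0, 3, (5,))

# Boolean indexing filters every field consistently
confident = results[results.scores_3d > 0.3]
print(len(confident.scores_3d) == len(confident.labels_3d))  # True
```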
import unittest

import torch
from mmengine import DefaultScope

from mmdet3d.core import LiDARInstance3DBoxes
from mmdet3d.registry import MODELS
from tests.utils.model_utils import (_create_detector_inputs,
                                     _get_detector_cfg, _setup_seed)


class TestFreeAnchor(unittest.TestCase):

    def test_freeanchor(self):
        import mmdet3d.models
        assert hasattr(mmdet3d.models.dense_heads, 'FreeAnchor3DHead')
        DefaultScope.get_instance('test_freeanchor', scope_name='mmdet3d')
        _setup_seed(0)
        freeanchor_cfg = _get_detector_cfg(
            'free_anchor/hv_pointpillars_fpn_sbn-all_free-'
            'anchor_4x8_2x_nus-3d.py')
        model = MODELS.build(freeanchor_cfg)
        num_gt_instance = 50
        data = [
            _create_detector_inputs(
                num_gt_instance=num_gt_instance, gt_bboxes_dim=9)
        ]
        aug_data = [
            _create_detector_inputs(
                num_gt_instance=num_gt_instance, gt_bboxes_dim=9),
            _create_detector_inputs(
                num_gt_instance=num_gt_instance + 1, gt_bboxes_dim=9)
        ]
        # test_aug_test
        metainfo = {
            'pcd_scale_factor': 1,
            'pcd_horizontal_flip': 1,
            'pcd_vertical_flip': 1,
            'box_type_3d': LiDARInstance3DBoxes
        }
        for item in aug_data:
            item['data_sample'].set_metainfo(metainfo)
        if torch.cuda.is_available():
            model = model.cuda()
            # test simple_test
            with torch.no_grad():
                batch_inputs, data_samples = model.data_preprocessor(
                    data, True)
                results = model.forward(
                    batch_inputs, data_samples, mode='predict')
                self.assertEqual(len(results), len(data))
                self.assertIn('bboxes_3d', results[0].pred_instances_3d)
                self.assertIn('scores_3d', results[0].pred_instances_3d)
                self.assertIn('labels_3d', results[0].pred_instances_3d)
                batch_inputs, data_samples = model.data_preprocessor(
                    aug_data, True)
                aug_results = model.forward(
                    batch_inputs, data_samples, mode='predict')
                self.assertEqual(len(aug_results), len(aug_data))
                self.assertIn('bboxes_3d', aug_results[0].pred_instances_3d)
                self.assertIn('scores_3d', aug_results[0].pred_instances_3d)
                self.assertIn('labels_3d', aug_results[0].pred_instances_3d)
                self.assertIn('bboxes_3d', aug_results[1].pred_instances_3d)
                self.assertIn('scores_3d', aug_results[1].pred_instances_3d)
                self.assertIn('labels_3d', aug_results[1].pred_instances_3d)

            losses = model.forward(batch_inputs, data_samples, mode='loss')
            self.assertGreater(losses['positive_bag_loss'], 0)
            self.assertGreater(losses['negative_bag_loss'], 0)
import unittest

import torch
from mmengine import DefaultScope

from mmdet3d.core import LiDARInstance3DBoxes
from mmdet3d.registry import MODELS
from tests.utils.model_utils import (_create_detector_inputs,
                                     _get_detector_cfg, _setup_seed)


class TestSSN(unittest.TestCase):

    def test_ssn(self):
        import mmdet3d.models
        assert hasattr(mmdet3d.models.dense_heads, 'ShapeAwareHead')
        DefaultScope.get_instance('test_ssn', scope_name='mmdet3d')
        _setup_seed(0)
        ssn_cfg = _get_detector_cfg(
            'ssn/hv_ssn_secfpn_sbn-all_2x16_2x_nus-3d.py')
        model = MODELS.build(ssn_cfg)
        num_gt_instance = 50
        data = [
            _create_detector_inputs(
                num_gt_instance=num_gt_instance, gt_bboxes_dim=9)
        ]
        aug_data = [
            _create_detector_inputs(
                num_gt_instance=num_gt_instance, gt_bboxes_dim=9),
            _create_detector_inputs(
                num_gt_instance=num_gt_instance + 1, gt_bboxes_dim=9)
        ]
        # test_aug_test
        metainfo = {
            'pcd_scale_factor': 1,
            'pcd_horizontal_flip': 1,
            'pcd_vertical_flip': 1,
            'box_type_3d': LiDARInstance3DBoxes
        }
        for item in aug_data:
            item['data_sample'].set_metainfo(metainfo)
        if torch.cuda.is_available():
            model = model.cuda()
            # test simple_test
            with torch.no_grad():
                batch_inputs, data_samples = model.data_preprocessor(
                    data, True)
                results = model.forward(
                    batch_inputs, data_samples, mode='predict')
                self.assertEqual(len(results), len(data))
                self.assertIn('bboxes_3d', results[0].pred_instances_3d)
                self.assertIn('scores_3d', results[0].pred_instances_3d)
                self.assertIn('labels_3d', results[0].pred_instances_3d)
                batch_inputs, data_samples = model.data_preprocessor(
                    aug_data, True)
                aug_results = model.forward(
                    batch_inputs, data_samples, mode='predict')
                self.assertEqual(len(aug_results), len(aug_data))
                self.assertIn('bboxes_3d', aug_results[0].pred_instances_3d)
                self.assertIn('scores_3d', aug_results[0].pred_instances_3d)
                self.assertIn('labels_3d', aug_results[0].pred_instances_3d)
                self.assertIn('bboxes_3d', aug_results[1].pred_instances_3d)
                self.assertIn('scores_3d', aug_results[1].pred_instances_3d)
                self.assertIn('labels_3d', aug_results[1].pred_instances_3d)

            losses = model.forward(batch_inputs, data_samples, mode='loss')
            self.assertGreater(losses['loss_cls'][0], 0)
            self.assertGreater(losses['loss_bbox'][0], 0)
            self.assertGreater(losses['loss_dir'][0], 0)
@@ -76,6 +76,7 @@ def _create_detector_inputs(seed=0,
                             with_img=False,
                             num_gt_instance=20,
                             points_feat_dim=4,
+                            gt_bboxes_dim=7,
                             num_classes=3):
     _setup_seed(seed)
     inputs_dict = dict()
@@ -88,7 +89,7 @@ def _create_detector_inputs(seed=0,
     gt_instance_3d = InstanceData()
     gt_instance_3d.bboxes_3d = LiDARInstance3DBoxes(
-        torch.rand([num_gt_instance, 7]))
+        torch.rand([num_gt_instance, gt_bboxes_dim]), box_dim=gt_bboxes_dim)
     gt_instance_3d.labels_3d = torch.randint(0, num_classes, [num_gt_instance])
     data_sample = Det3DDataSample(
         metainfo=dict(box_type_3d=LiDARInstance3DBoxes))
...
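The new `gt_bboxes_dim` argument lets the tests above create 9-dim nuScenes-style boxes (x, y, z, dx, dy, dz, yaw plus two velocity components); `box_dim` must be passed explicitly because `LiDARInstance3DBoxes` otherwise assumes 7 dims. A minimal sketch using the same import path as the tests:

```python
# Hedged sketch: 9-dim boxes carry two extra velocity components, so the
# box class needs box_dim=9 to parse the tensor correctly.
import torch
from mmdet3d.core import LiDARInstance3DBoxes

boxes = LiDARInstance3DBoxes(torch.rand(4, 9), box_dim=9)
print(boxes.tensor.shape)  # torch.Size([4, 9])
```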