Unverified commit b016d90c authored by twang, committed by GitHub

[Feature] Add shape-aware grouping head in SSN (#147)

* Add shape-aware grouping head

* Reformat docstrings

* Remove rewritten get_anchors in shape_aware_head

* Refactor and simplify shape-aware grouping head

* Fix docstring

* Remove fixed preset shape heads

* Add unittest for AlignedAnchor3DRangeGeneratorPerCls

* Add unittest for get bboxes in shape_aware_head

* Add unittest for loss of shape_aware_head

* Fix non-standard docstrings

* Minor fix for a comment

* Add assertion to make sure not all boxes are filtered
parent cec2d8b0
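For orientation, the `tasks` argument groups classes of similar shape into branches, each with its own shared conv stack. A minimal, hypothetical config sketch follows (channel and stride values are illustrative, not the shipped SSN configs under configs/ssn/):

# Hypothetical ShapeAwareHead config sketch; values are illustrative only.
# Each task groups classes of similar shape; larger objects get heavier
# shared conv branches with an extra stride.
bbox_head = dict(
    type='ShapeAwareHead',
    num_classes=9,
    in_channels=384,
    feat_channels=384,
    tasks=[
        dict(num_class=2,  # e.g. small objects
             shared_conv_channels=(64, 64),
             shared_conv_strides=(1, 1)),
        dict(num_class=1,  # e.g. large objects: heavier head
             shared_conv_channels=(64, 64, 64),
             shared_conv_strides=(2, 1, 1)),
    ],
    assign_per_class=True)
    # plus anchor_generator, bbox_coder and loss settings as in Anchor3DHead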
from mmdet.core.anchor import build_anchor_generator
from .anchor_3d_generator import (AlignedAnchor3DRangeGenerator,
                                  AlignedAnchor3DRangeGeneratorPerCls,
                                  Anchor3DRangeGenerator)

__all__ = [
    'AlignedAnchor3DRangeGenerator', 'Anchor3DRangeGenerator',
    'build_anchor_generator', 'AlignedAnchor3DRangeGeneratorPerCls'
]
@@ -323,3 +323,81 @@ class AlignedAnchor3DRangeGenerator(Anchor3DRangeGenerator):
            # custom[:] = self.custom_values
            ret = torch.cat([ret, custom], dim=-1)
        return ret
@ANCHOR_GENERATORS.register_module()
class AlignedAnchor3DRangeGeneratorPerCls(AlignedAnchor3DRangeGenerator):
"""3D Anchor Generator by range for per class.
This anchor generator generates anchors by the given range for per class.
Note that feature maps of different classes may be different.
Args:
kwargs (dict): Arguments are the same as those in \
:class:`AlignedAnchor3DRangeGenerator`.
"""
def __init__(self, **kwargs):
super(AlignedAnchor3DRangeGeneratorPerCls, self).__init__(**kwargs)
assert len(self.scales) == 1, 'Multi-scale feature map levels are' + \
' not supported currently in this kind of anchor generator.'
def grid_anchors(self, featmap_sizes, device='cuda'):
"""Generate grid anchors in multiple feature levels.
Args:
featmap_sizes (list[tuple]): List of feature map sizes for \
different classes in a single feature level.
device (str): Device where the anchors will be put on.
Returns:
list[list[torch.Tensor]]: Anchors in multiple feature levels. \
Note that in this anchor generator, we currently only \
support single feature level. The sizes of each tensor \
should be [num_sizes/ranges*num_rots*featmap_size, \
box_code_size].
"""
multi_level_anchors = []
anchors = self.multi_cls_grid_anchors(
featmap_sizes, self.scales[0], device=device)
multi_level_anchors.append(anchors)
return multi_level_anchors
def multi_cls_grid_anchors(self, featmap_sizes, scale, device='cuda'):
"""Generate grid anchors of a single level feature map for multi-class
with different feature map sizes.
This function is usually called by method ``self.grid_anchors``.
Args:
featmap_sizes (list[tuple]): List of feature map sizes for \
different classes in a single feature level.
scale (float): Scale factor of the anchors in the current level.
device (str, optional): Device the tensor will be put on.
Defaults to 'cuda'.
Returns:
torch.Tensor: Anchors in the overall feature map.
"""
        assert len(featmap_sizes) == len(self.sizes) == len(self.ranges), \
            'The number of different feature map sizes, anchor sizes and ' + \
            'ranges should be the same.'
multi_cls_anchors = []
for i in range(len(featmap_sizes)):
anchors = self.anchors_single_range(
featmap_sizes[i],
self.ranges[i],
scale,
self.sizes[i],
self.rotations,
device=device)
# [*featmap_size, num_sizes/ranges, num_rots, box_code_size]
ndim = len(featmap_sizes[i])
anchors = anchors.view(*featmap_sizes[i], -1, anchors.size(-1))
# [*featmap_size, num_sizes/ranges*num_rots, box_code_size]
anchors = anchors.permute(ndim, *range(0, ndim), ndim + 1)
# [num_sizes/ranges*num_rots, *featmap_size, box_code_size]
multi_cls_anchors.append(anchors.reshape(-1, anchors.size(-1)))
# [num_sizes/ranges*num_rots*featmap_size, box_code_size]
return multi_cls_anchors
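A short usage sketch for the per-class generator, mirroring the unit test added further down (all values come from that test):

from mmdet.core.anchor import build_anchor_generator

# Sketch: two classes with different ranges/sizes and different feature
# map sizes in a single feature level (values mirror the unit test below).
cfg = dict(
    type='AlignedAnchor3DRangeGeneratorPerCls',
    ranges=[[-100, -100, -1.80, 100, 100, -1.80],
            [-100, -100, -1.30, 100, 100, -1.30]],
    sizes=[[0.63, 1.76, 1.44], [0.96, 2.35, 1.59]],
    custom_values=[0, 0],
    rotations=[0, 1.57],
    reshape_out=False)
gen = build_anchor_generator(cfg)
anchors = gen.grid_anchors([(100, 100), (50, 50)], device='cpu')
# anchors[0][0]: [100*100*1*2, 9] anchors for class 0
# anchors[0][1]: [50*50*1*2, 9] anchors for class 1
# (9 = box_code_size 7 + 2 custom values)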
@@ -3,10 +3,11 @@ from .base_conv_bbox_head import BaseConvBboxHead
from .centerpoint_head import CenterHead
from .free_anchor3d_head import FreeAnchor3DHead
from .parta2_rpn_head import PartA2RPNHead
from .shape_aware_head import ShapeAwareHead
from .ssd_3d_head import SSD3DHead
from .vote_head import VoteHead

__all__ = [
    'Anchor3DHead', 'FreeAnchor3DHead', 'PartA2RPNHead', 'VoteHead',
    'SSD3DHead', 'BaseConvBboxHead', 'CenterHead', 'ShapeAwareHead'
]
import numpy as np
import torch
from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init
from torch import nn as nn
from mmdet3d.core import box3d_multiclass_nms, limit_period, xywhr2xyxyr
from mmdet.core import multi_apply
from mmdet.models import HEADS
from ..builder import build_head
from .anchor3d_head import Anchor3DHead
@HEADS.register_module()
class BaseShapeHead(nn.Module):
"""Base Shape-aware Head in Shape Signature Network.
Note:
This base shape-aware grouping head uses default settings for small
objects. For large and huge objects, it is recommended to use
heavier heads, like (64, 64, 64) and (128, 128, 64, 64, 64) in
shared conv channels, (2, 1, 1) and (2, 1, 2, 1, 1) in shared
conv strides. For tiny objects, we can use smaller heads, like
(32, 32) channels and (1, 1) strides.
Args:
num_cls (int): Number of classes.
num_base_anchors (int): Number of anchors per location.
box_code_size (int): The dimension of boxes to be encoded.
in_channels (int): Input channels for convolutional layers.
shared_conv_channels (tuple): Channels for shared convolutional \
layers. Default: (64, 64). \
shared_conv_strides (tuple): Strides for shared convolutional \
layers. Default: (1, 1).
use_direction_classifier (bool, optional): Whether to use direction \
classifier. Default: True.
conv_cfg (dict): Config of conv layer. Default: dict(type='Conv2d')
norm_cfg (dict): Config of norm layer. Default: dict(type='BN2d').
bias (bool|str, optional): Type of bias. Default: False.
"""
def __init__(self,
num_cls,
num_base_anchors,
box_code_size,
in_channels,
shared_conv_channels=(64, 64),
shared_conv_strides=(1, 1),
use_direction_classifier=True,
conv_cfg=dict(type='Conv2d'),
norm_cfg=dict(type='BN2d'),
bias=False):
super().__init__()
self.num_cls = num_cls
self.num_base_anchors = num_base_anchors
self.use_direction_classifier = use_direction_classifier
self.box_code_size = box_code_size
assert len(shared_conv_channels) == len(shared_conv_strides), \
'Lengths of channels and strides list should be equal.'
self.shared_conv_channels = [in_channels] + list(shared_conv_channels)
self.shared_conv_strides = list(shared_conv_strides)
shared_conv = []
for i in range(len(self.shared_conv_strides)):
shared_conv.append(
ConvModule(
self.shared_conv_channels[i],
self.shared_conv_channels[i + 1],
kernel_size=3,
stride=self.shared_conv_strides[i],
padding=1,
conv_cfg=conv_cfg,
bias=bias,
norm_cfg=norm_cfg))
self.shared_conv = nn.Sequential(*shared_conv)
out_channels = self.shared_conv_channels[-1]
self.conv_cls = nn.Conv2d(out_channels, num_base_anchors * num_cls, 1)
self.conv_reg = nn.Conv2d(out_channels,
num_base_anchors * box_code_size, 1)
if use_direction_classifier:
self.conv_dir_cls = nn.Conv2d(out_channels, num_base_anchors * 2,
1)
def init_weights(self):
"""Initialize weights."""
bias_cls = bias_init_with_prob(0.01)
# shared conv layers have already been initialized by ConvModule
normal_init(self.conv_cls, std=0.01, bias=bias_cls)
normal_init(self.conv_reg, std=0.01)
if self.use_direction_classifier:
normal_init(self.conv_dir_cls, std=0.01, bias=bias_cls)
def forward(self, x):
"""Forward function for SmallHead.
Args:
x (torch.Tensor): Input feature map with the shape of
[B, C, H, W].
Returns:
dict[torch.Tensor]: Contain score of each class, bbox \
regression and direction classification predictions. \
Note that all the returned tensors are reshaped as \
[bs*num_base_anchors*H*W, num_cls/box_code_size/dir_bins]. \
It is more convenient to concat anchors for different \
classes even though they have different feature map sizes.
"""
x = self.shared_conv(x)
cls_score = self.conv_cls(x)
bbox_pred = self.conv_reg(x)
featmap_size = bbox_pred.shape[-2:]
H, W = featmap_size
B = bbox_pred.shape[0]
cls_score = cls_score.view(-1, self.num_base_anchors, self.num_cls, H,
W).permute(0, 1, 3, 4,
2).reshape(B, -1, self.num_cls)
bbox_pred = bbox_pred.view(-1, self.num_base_anchors,
self.box_code_size, H, W).permute(
0, 1, 3, 4,
2).reshape(B, -1, self.box_code_size)
dir_cls_preds = None
if self.use_direction_classifier:
dir_cls_preds = self.conv_dir_cls(x)
dir_cls_preds = dir_cls_preds.view(-1, self.num_base_anchors, 2, H,
W).permute(0, 1, 3, 4,
2).reshape(B, -1, 2)
ret = dict(
cls_score=cls_score,
bbox_pred=bbox_pred,
dir_cls_preds=dir_cls_preds,
featmap_size=featmap_size)
return ret
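The view/permute/reshape sequence above flattens per-anchor predictions so that branches with different feature map sizes can later be concatenated. A minimal standalone sketch with assumed sizes (batch 2, 2 base anchors, 9 classes, 100x100 map):

import torch

# Assumed sizes: B=2, num_base_anchors=2, num_cls=9, 100x100 feature map.
B, A, C, H, W = 2, 2, 9, 100, 100
cls_out = torch.rand(B, A * C, H, W)  # raw conv_cls output
cls_score = cls_out.view(-1, A, C, H, W).permute(0, 1, 3, 4, 2)
cls_score = cls_score.reshape(B, -1, C)
assert cls_score.shape == (B, A * H * W, C)  # [2, 20000, 9]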
@HEADS.register_module()
class ShapeAwareHead(Anchor3DHead):
"""Shape-aware grouping head for SSN.
Args:
tasks (dict): Shape-aware groups of multi-class objects.
assign_per_class (bool, optional): Whether to do assignment for each \
class. Default: True.
kwargs (dict): Other arguments are the same as those in \
:class:`Anchor3DHead`.
"""
def __init__(self, tasks, assign_per_class=True, **kwargs):
self.tasks = tasks
self.featmap_sizes = []
super().__init__(assign_per_class=assign_per_class, **kwargs)
def _init_layers(self):
"""Initialize neural network layers of the head."""
self.heads = nn.ModuleList()
cls_ptr = 0
for task in self.tasks:
sizes = self.anchor_generator.sizes[cls_ptr:cls_ptr +
task['num_class']]
num_size = torch.tensor(sizes).reshape(-1, 3).size(0)
num_rot = len(self.anchor_generator.rotations)
num_base_anchors = num_rot * num_size
branch = dict(
type='BaseShapeHead',
num_cls=self.num_classes,
num_base_anchors=num_base_anchors,
box_code_size=self.box_code_size,
in_channels=self.in_channels,
shared_conv_channels=task['shared_conv_channels'],
shared_conv_strides=task['shared_conv_strides'])
self.heads.append(build_head(branch))
cls_ptr += task['num_class']
def init_weights(self):
"""Initialize the weights of head."""
for head in self.heads:
head.init_weights()
def forward_single(self, x):
"""Forward function on a single-scale feature map.
Args:
x (torch.Tensor): Input features.
Returns:
tuple[torch.Tensor]: Contain score of each class, bbox \
regression and direction classification predictions.
"""
results = []
for head in self.heads:
results.append(head(x))
cls_score = torch.cat([result['cls_score'] for result in results],
dim=1)
bbox_pred = torch.cat([result['bbox_pred'] for result in results],
dim=1)
dir_cls_preds = None
if self.use_direction_classifier:
dir_cls_preds = torch.cat(
[result['dir_cls_preds'] for result in results], dim=1)
self.featmap_sizes = []
for i, task in enumerate(self.tasks):
for _ in range(task['num_class']):
self.featmap_sizes.append(results[i]['featmap_size'])
        assert len(self.featmap_sizes) == len(self.anchor_generator.ranges), \
            'The number of feature map sizes must be equal to the number ' + \
            'of different ranges in the anchor generator.'
return cls_score, bbox_pred, dir_cls_preds
def loss_single(self, cls_score, bbox_pred, dir_cls_preds, labels,
label_weights, bbox_targets, bbox_weights, dir_targets,
dir_weights, num_total_samples):
"""Calculate loss of Single-level results.
Args:
cls_score (torch.Tensor): Class score in single-level.
bbox_pred (torch.Tensor): Bbox prediction in single-level.
dir_cls_preds (torch.Tensor): Predictions of direction class
in single-level.
labels (torch.Tensor): Labels of class.
label_weights (torch.Tensor): Weights of class loss.
bbox_targets (torch.Tensor): Targets of bbox predictions.
bbox_weights (torch.Tensor): Weights of bbox loss.
dir_targets (torch.Tensor): Targets of direction predictions.
dir_weights (torch.Tensor): Weights of direction loss.
num_total_samples (int): The number of valid samples.
Returns:
tuple[torch.Tensor]: Losses of class, bbox \
and direction, respectively.
"""
# classification loss
if num_total_samples is None:
num_total_samples = int(cls_score.shape[0])
labels = labels.reshape(-1)
label_weights = label_weights.reshape(-1)
cls_score = cls_score.reshape(-1, self.num_classes)
loss_cls = self.loss_cls(
cls_score, labels, label_weights, avg_factor=num_total_samples)
# regression loss
bbox_targets = bbox_targets.reshape(-1, self.box_code_size)
bbox_weights = bbox_weights.reshape(-1, self.box_code_size)
code_weight = self.train_cfg.get('code_weight', None)
if code_weight:
bbox_weights = bbox_weights * bbox_weights.new_tensor(code_weight)
bbox_pred = bbox_pred.reshape(-1, self.box_code_size)
if self.diff_rad_by_sin:
bbox_pred, bbox_targets = self.add_sin_difference(
bbox_pred, bbox_targets)
loss_bbox = self.loss_bbox(
bbox_pred,
bbox_targets,
bbox_weights,
avg_factor=num_total_samples)
# direction classification loss
loss_dir = None
if self.use_direction_classifier:
dir_cls_preds = dir_cls_preds.reshape(-1, 2)
dir_targets = dir_targets.reshape(-1)
dir_weights = dir_weights.reshape(-1)
loss_dir = self.loss_dir(
dir_cls_preds,
dir_targets,
dir_weights,
avg_factor=num_total_samples)
return loss_cls, loss_bbox, loss_dir
def loss(self,
cls_scores,
bbox_preds,
dir_cls_preds,
gt_bboxes,
gt_labels,
input_metas,
gt_bboxes_ignore=None):
"""Calculate losses.
Args:
cls_scores (list[torch.Tensor]): Multi-level class scores.
bbox_preds (list[torch.Tensor]): Multi-level bbox predictions.
dir_cls_preds (list[torch.Tensor]): Multi-level direction
class predictions.
gt_bboxes (list[:obj:`BaseInstance3DBoxes`]): Gt bboxes
of each sample.
gt_labels (list[torch.Tensor]): Gt labels of each sample.
input_metas (list[dict]): Contain pcd and img's meta info.
            gt_bboxes_ignore (None | list[torch.Tensor]): Specify
                which bounding boxes can be ignored when computing the loss.
Returns:
dict[str, list[torch.Tensor]]: Classification, bbox, and \
direction losses of each level.
- loss_cls (list[torch.Tensor]): Classification losses.
- loss_bbox (list[torch.Tensor]): Box regression losses.
- loss_dir (list[torch.Tensor]): Direction classification \
losses.
"""
device = cls_scores[0].device
anchor_list = self.get_anchors(
self.featmap_sizes, input_metas, device=device)
cls_reg_targets = self.anchor_target_3d(
anchor_list,
gt_bboxes,
input_metas,
gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels,
num_classes=self.num_classes,
sampling=self.sampling)
if cls_reg_targets is None:
return None
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
dir_targets_list, dir_weights_list, num_total_pos,
num_total_neg) = cls_reg_targets
num_total_samples = (
num_total_pos + num_total_neg if self.sampling else num_total_pos)
# num_total_samples = None
losses_cls, losses_bbox, losses_dir = multi_apply(
self.loss_single,
cls_scores,
bbox_preds,
dir_cls_preds,
labels_list,
label_weights_list,
bbox_targets_list,
bbox_weights_list,
dir_targets_list,
dir_weights_list,
num_total_samples=num_total_samples)
return dict(
loss_cls=losses_cls, loss_bbox=losses_bbox, loss_dir=losses_dir)
def get_bboxes(self,
cls_scores,
bbox_preds,
dir_cls_preds,
input_metas,
cfg=None,
rescale=False):
"""Get bboxes of anchor head.
Args:
cls_scores (list[torch.Tensor]): Multi-level class scores.
bbox_preds (list[torch.Tensor]): Multi-level bbox predictions.
dir_cls_preds (list[torch.Tensor]): Multi-level direction
class predictions.
input_metas (list[dict]): Contain pcd and img's meta info.
cfg (None | :obj:`ConfigDict`): Training or testing config.
Default: None.
            rescale (bool, optional): Whether to rescale bboxes.
                Default: False.
Returns:
            list[tuple]: Prediction results of batches.
"""
assert len(cls_scores) == len(bbox_preds)
assert len(cls_scores) == len(dir_cls_preds)
num_levels = len(cls_scores)
assert num_levels == 1, 'Only support single level inference.'
device = cls_scores[0].device
mlvl_anchors = self.anchor_generator.grid_anchors(
self.featmap_sizes, device=device)
# `anchor` is a list of anchors for different classes
mlvl_anchors = [torch.cat(anchor, dim=0) for anchor in mlvl_anchors]
result_list = []
for img_id in range(len(input_metas)):
cls_score_list = [
cls_scores[i][img_id].detach() for i in range(num_levels)
]
bbox_pred_list = [
bbox_preds[i][img_id].detach() for i in range(num_levels)
]
dir_cls_pred_list = [
dir_cls_preds[i][img_id].detach() for i in range(num_levels)
]
input_meta = input_metas[img_id]
proposals = self.get_bboxes_single(cls_score_list, bbox_pred_list,
dir_cls_pred_list, mlvl_anchors,
input_meta, cfg, rescale)
result_list.append(proposals)
return result_list
def get_bboxes_single(self,
cls_scores,
bbox_preds,
dir_cls_preds,
mlvl_anchors,
input_meta,
cfg=None,
rescale=False):
"""Get bboxes of single branch.
Args:
cls_scores (torch.Tensor): Class score in single batch.
bbox_preds (torch.Tensor): Bbox prediction in single batch.
dir_cls_preds (torch.Tensor): Predictions of direction class
in single batch.
mlvl_anchors (List[torch.Tensor]): Multi-level anchors
in single batch.
input_meta (list[dict]): Contain pcd and img's meta info.
cfg (None | :obj:`ConfigDict`): Training or testing config.
rescale (list[torch.Tensor], optional): whether to rescale bbox. \
Default: False.
Returns:
tuple: Contain predictions of single batch.
- bboxes (:obj:`BaseInstance3DBoxes`): Predicted 3d bboxes.
- scores (torch.Tensor): Class score of each bbox.
- labels (torch.Tensor): Label of each bbox.
"""
cfg = self.test_cfg if cfg is None else cfg
assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
mlvl_bboxes = []
mlvl_scores = []
mlvl_dir_scores = []
for cls_score, bbox_pred, dir_cls_pred, anchors in zip(
cls_scores, bbox_preds, dir_cls_preds, mlvl_anchors):
assert cls_score.size()[-2] == bbox_pred.size()[-2]
assert cls_score.size()[-2] == dir_cls_pred.size()[-2]
dir_cls_score = torch.max(dir_cls_pred, dim=-1)[1]
if self.use_sigmoid_cls:
scores = cls_score.sigmoid()
else:
scores = cls_score.softmax(-1)
nms_pre = cfg.get('nms_pre', -1)
if nms_pre > 0 and scores.shape[0] > nms_pre:
if self.use_sigmoid_cls:
max_scores, _ = scores.max(dim=1)
else:
max_scores, _ = scores[:, :-1].max(dim=1)
_, topk_inds = max_scores.topk(nms_pre)
anchors = anchors[topk_inds, :]
bbox_pred = bbox_pred[topk_inds, :]
scores = scores[topk_inds, :]
dir_cls_score = dir_cls_score[topk_inds]
bboxes = self.bbox_coder.decode(anchors, bbox_pred)
mlvl_bboxes.append(bboxes)
mlvl_scores.append(scores)
mlvl_dir_scores.append(dir_cls_score)
mlvl_bboxes = torch.cat(mlvl_bboxes)
mlvl_bboxes_for_nms = xywhr2xyxyr(input_meta['box_type_3d'](
mlvl_bboxes, box_dim=self.box_code_size).bev)
mlvl_scores = torch.cat(mlvl_scores)
mlvl_dir_scores = torch.cat(mlvl_dir_scores)
if self.use_sigmoid_cls:
# Add a dummy background class to the front when using sigmoid
padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
score_thr = cfg.get('score_thr', 0)
results = box3d_multiclass_nms(mlvl_bboxes, mlvl_bboxes_for_nms,
mlvl_scores, score_thr, cfg.max_num,
cfg, mlvl_dir_scores)
bboxes, scores, labels, dir_scores = results
if bboxes.shape[0] > 0:
dir_rot = limit_period(bboxes[..., 6] - self.dir_offset,
self.dir_limit_offset, np.pi)
bboxes[..., 6] = (
dir_rot + self.dir_offset +
np.pi * dir_scores.to(bboxes.dtype))
bboxes = input_meta['box_type_3d'](bboxes, box_dim=self.box_code_size)
return bboxes, scores, labels
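Putting the pieces together, inference with the head follows this rough flow (a sketch; head, feats and input_metas are placeholders, assumed to be built as in the unit tests below):

# Rough inference flow (sketch; names are placeholders).
# feats: list with one BEV feature map, e.g. [B, 384, 200, 200]
cls_scores, bbox_preds, dir_cls_preds = head(feats)
# Each output is a single-level list of [B, total_anchors, *] tensors;
# total_anchors concatenates all class branches, which is why
# forward_single records self.featmap_sizes for get_anchors/get_bboxes.
results = head.get_bboxes(cls_scores, bbox_preds, dir_cls_preds, input_metas)
# results[i] is a (bboxes, scores, labels) tuple for sample i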
@@ -40,6 +40,16 @@ class AnchorTrainMixin(object):
        num_imgs = len(input_metas)
        assert len(anchor_list) == num_imgs
if isinstance(anchor_list[0][0], list):
# sizes of anchors are different
# anchor number of a single level
num_level_anchors = [
sum([anchor.size(0) for anchor in anchors])
for anchors in anchor_list[0]
]
for i in range(num_imgs):
anchor_list[i] = anchor_list[i][0]
else:
            # anchor number of multi levels
            num_level_anchors = [
                anchors.view(-1, self.box_code_size).size(0)
@@ -112,7 +122,8 @@ class AnchorTrainMixin(object):
        Returns:
            tuple[torch.Tensor]: Anchor targets.
        """
        if isinstance(self.bbox_assigner,
                      list) and (not isinstance(anchors, list)):
            feat_size = anchors.size(0) * anchors.size(1) * anchors.size(2)
            rot_angles = anchors.size(-2)
            assert len(self.bbox_assigner) == anchors.size(-3)
@@ -129,12 +140,11 @@ class AnchorTrainMixin(object):
                    anchor_targets = self.anchor_target_single_assigner(
                        assigner, current_anchors, gt_bboxes[gt_per_cls, :],
                        gt_bboxes_ignore, gt_labels[gt_per_cls], input_meta,
                        num_classes, sampling)
                else:
                    anchor_targets = self.anchor_target_single_assigner(
                        assigner, current_anchors, gt_bboxes, gt_bboxes_ignore,
                        gt_labels, input_meta, num_classes, sampling)
                (labels, label_weights, bbox_targets, bbox_weights,
                 dir_targets, dir_weights, pos_inds, neg_inds) = anchor_targets
@@ -170,10 +180,59 @@ class AnchorTrainMixin(object):
            return (total_labels, total_label_weights, total_bbox_targets,
                    total_bbox_weights, total_dir_targets, total_dir_weights,
                    total_pos_inds, total_neg_inds)
elif isinstance(self.bbox_assigner, list) and isinstance(
anchors, list):
# class-aware anchors with different feature map sizes
assert len(self.bbox_assigner) == len(anchors), \
'The number of bbox assigners and anchors should be the same.'
(total_labels, total_label_weights, total_bbox_targets,
total_bbox_weights, total_dir_targets, total_dir_weights,
total_pos_inds, total_neg_inds) = [], [], [], [], [], [], [], []
current_anchor_num = 0
for i, assigner in enumerate(self.bbox_assigner):
current_anchors = anchors[i]
current_anchor_num += current_anchors.size(0)
if self.assign_per_class:
gt_per_cls = (gt_labels == i)
anchor_targets = self.anchor_target_single_assigner(
assigner, current_anchors, gt_bboxes[gt_per_cls, :],
gt_bboxes_ignore, gt_labels[gt_per_cls], input_meta,
num_classes, sampling)
                else:
                    anchor_targets = self.anchor_target_single_assigner(
                        assigner, current_anchors, gt_bboxes, gt_bboxes_ignore,
                        gt_labels, input_meta, num_classes, sampling)
(labels, label_weights, bbox_targets, bbox_weights,
dir_targets, dir_weights, pos_inds, neg_inds) = anchor_targets
total_labels.append(labels)
total_label_weights.append(label_weights)
total_bbox_targets.append(
bbox_targets.reshape(-1, anchors[i].size(-1)))
total_bbox_weights.append(
bbox_weights.reshape(-1, anchors[i].size(-1)))
total_dir_targets.append(dir_targets)
total_dir_weights.append(dir_weights)
total_pos_inds.append(pos_inds)
total_neg_inds.append(neg_inds)
total_labels = torch.cat(total_labels, dim=0)
total_label_weights = torch.cat(total_label_weights, dim=0)
total_bbox_targets = torch.cat(total_bbox_targets, dim=0)
total_bbox_weights = torch.cat(total_bbox_weights, dim=0)
total_dir_targets = torch.cat(total_dir_targets, dim=0)
total_dir_weights = torch.cat(total_dir_weights, dim=0)
total_pos_inds = torch.cat(total_pos_inds, dim=0)
total_neg_inds = torch.cat(total_neg_inds, dim=0)
return (total_labels, total_label_weights, total_bbox_targets,
total_bbox_weights, total_dir_targets, total_dir_weights,
total_pos_inds, total_neg_inds)
else:
return self.anchor_target_single_assigner(self.bbox_assigner,
anchors, gt_bboxes,
gt_bboxes_ignore,
gt_labels, input_meta,
num_classes, sampling)
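The new class-aware branch computes targets per class branch and concatenates them along the anchor dimension, matching the concatenated per-class anchors. A toy illustration (shapes assumed, not repo code):

import torch

# Toy shapes: class 0 has 100*100*2 anchors, class 1 has 50*50*2 anchors.
labels_cls0 = torch.zeros(20000, dtype=torch.long)
labels_cls1 = torch.ones(5000, dtype=torch.long)
total_labels = torch.cat([labels_cls0, labels_cls1], dim=0)
assert total_labels.shape[0] == 25000  # one flat target vector per sample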
    def anchor_target_single_assigner(self,
                                      bbox_assigner,
@@ -182,7 +241,6 @@ class AnchorTrainMixin(object):
                                      gt_bboxes_ignore,
                                      gt_labels,
                                      input_meta,
                                      num_classes=1,
                                      sampling=True):
        """Assign anchors and encode positive anchors.
@@ -194,7 +252,6 @@ class AnchorTrainMixin(object):
            gt_bboxes_ignore (torch.Tensor): Ignored gt bboxes.
            gt_labels (torch.Tensor): Gt class labels.
            input_meta (dict): Meta info of each image.
            num_classes (int): The number of classes.
            sampling (bool): Whether to sample anchors.
...
@@ -181,3 +181,58 @@ def test_aligned_anchor_generator():
    # set [:56:7] thus it could cover 8 (len(size) * len(rotations))
    # anchors on 8 locations
    assert single_level_anchor[:56:7].allclose(expected_grid_anchors[i])
def test_aligned_anchor_generator_per_cls():
if torch.cuda.is_available():
device = 'cuda'
else:
device = 'cpu'
anchor_generator_cfg = dict(
type='AlignedAnchor3DRangeGeneratorPerCls',
ranges=[[-100, -100, -1.80, 100, 100, -1.80],
[-100, -100, -1.30, 100, 100, -1.30]],
sizes=[[0.63, 1.76, 1.44], [0.96, 2.35, 1.59]],
custom_values=[0, 0],
rotations=[0, 1.57],
reshape_out=False)
featmap_sizes = [(100, 100), (50, 50)]
anchor_generator = build_anchor_generator(anchor_generator_cfg)
# check base anchors
expected_grid_anchors = [[
torch.tensor([[
-99.0000, -99.0000, -1.8000, 0.6300, 1.7600, 1.4400, 0.0000,
0.0000, 0.0000
],
[
-99.0000, -99.0000, -1.8000, 0.6300, 1.7600, 1.4400,
1.5700, 0.0000, 0.0000
]],
device=device),
torch.tensor([[
-98.0000, -98.0000, -1.3000, 0.9600, 2.3500, 1.5900, 0.0000,
0.0000, 0.0000
],
[
-98.0000, -98.0000, -1.3000, 0.9600, 2.3500, 1.5900,
1.5700, 0.0000, 0.0000
]],
device=device)
]]
multi_level_anchors = anchor_generator.grid_anchors(
featmap_sizes, device=device)
expected_multi_level_shapes = [[
torch.Size([20000, 9]), torch.Size([5000, 9])
]]
for i, single_level_anchor in enumerate(multi_level_anchors):
assert len(single_level_anchor) == len(expected_multi_level_shapes[i])
        # set [:2*interval:interval] thus it could cover
        # 2 (len(size) * len(rotations)) anchors on 2 locations
        # Note that len(size) for each class is always 1 in this case
for j in range(len(single_level_anchor)):
interval = int(expected_multi_level_shapes[i][j][0] / 2)
assert single_level_anchor[j][:2 * interval:interval].allclose(
expected_grid_anchors[i][j])
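The expected shapes follow directly from the anchor layout; a quick arithmetic check under the config above:

# Per class: num_anchors = H * W * len(sizes_per_cls) * len(rotations),
# and each anchor has 7 box dims + 2 custom values = 9.
assert 100 * 100 * 1 * 2 == 20000  # class 0, torch.Size([20000, 9])
assert 50 * 50 * 1 * 2 == 5000     # class 1, torch.Size([5000, 9])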
@@ -945,3 +945,102 @@ def test_ssd3d_head():
    assert results[0][0].tensor.shape[1] == 7
    assert results[0][1].shape[0] >= 0
    assert results[0][2].shape[0] >= 0
def test_shape_aware_head_loss():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
bbox_head_cfg = _get_pts_bbox_head_cfg(
'ssn/hv_ssn_secfpn_sbn-all_2x16_2x_lyft-3d.py')
# modify bn config to avoid bugs caused by syncbn
for task in bbox_head_cfg['tasks']:
task['norm_cfg'] = dict(type='BN2d')
from mmdet3d.models.builder import build_head
self = build_head(bbox_head_cfg)
self.cuda()
assert len(self.heads) == 4
assert isinstance(self.heads[0].conv_cls, torch.nn.modules.conv.Conv2d)
assert self.heads[0].conv_cls.in_channels == 64
assert self.heads[0].conv_cls.out_channels == 36
assert self.heads[0].conv_reg.out_channels == 28
assert self.heads[0].conv_dir_cls.out_channels == 8
# test forward
feats = list()
feats.append(torch.rand([2, 384, 200, 200], dtype=torch.float32).cuda())
(cls_score, bbox_pred, dir_cls_preds) = self.forward(feats)
assert cls_score[0].shape == torch.Size([2, 420000, 9])
assert bbox_pred[0].shape == torch.Size([2, 420000, 7])
assert dir_cls_preds[0].shape == torch.Size([2, 420000, 2])
# test loss
gt_bboxes = [
LiDARInstance3DBoxes(
torch.tensor(
[[-14.5695, -6.4169, -2.1054, 1.8830, 4.6720, 1.4840, 1.5587],
[25.7215, 3.4581, -1.3456, 1.6720, 4.4090, 1.5830, 1.5301]],
dtype=torch.float32).cuda()),
LiDARInstance3DBoxes(
torch.tensor(
[[-50.763, -3.5517, -0.99658, 1.7430, 4.4020, 1.6990, 1.7874],
[-68.720, 0.033, -0.75276, 1.7860, 4.9100, 1.6610, 1.7525]],
dtype=torch.float32).cuda())
]
gt_labels = list(torch.tensor([[4, 4], [4, 4]], dtype=torch.int64).cuda())
input_metas = [{
'sample_idx': 1234
}, {
'sample_idx': 2345
}] # fake input_metas
losses = self.loss(cls_score, bbox_pred, dir_cls_preds, gt_bboxes,
gt_labels, input_metas)
assert losses['loss_cls'][0] > 0
assert losses['loss_bbox'][0] > 0
assert losses['loss_dir'][0] > 0
# test empty ground truth case
gt_bboxes = list(torch.empty((2, 0, 7)).cuda())
gt_labels = list(torch.empty((2, 0)).cuda())
empty_gt_losses = self.loss(cls_score, bbox_pred, dir_cls_preds, gt_bboxes,
gt_labels, input_metas)
assert empty_gt_losses['loss_cls'][0] > 0
assert empty_gt_losses['loss_bbox'][0] == 0
assert empty_gt_losses['loss_dir'][0] == 0
def test_shape_aware_head_getboxes():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
bbox_head_cfg = _get_pts_bbox_head_cfg(
'ssn/hv_ssn_secfpn_sbn-all_2x16_2x_lyft-3d.py')
# modify bn config to avoid bugs caused by syncbn
for task in bbox_head_cfg['tasks']:
task['norm_cfg'] = dict(type='BN2d')
from mmdet3d.models.builder import build_head
self = build_head(bbox_head_cfg)
self.cuda()
feats = list()
feats.append(torch.rand([2, 384, 200, 200], dtype=torch.float32).cuda())
# fake input_metas
input_metas = [{
'sample_idx': 1234,
'box_type_3d': LiDARInstance3DBoxes,
'box_mode_3d': Box3DMode.LIDAR
}, {
'sample_idx': 2345,
'box_type_3d': LiDARInstance3DBoxes,
'box_mode_3d': Box3DMode.LIDAR
}]
(cls_score, bbox_pred, dir_cls_preds) = self.forward(feats)
# test get_bboxes
cls_score[0] -= 1.5 # too many positive samples may cause cuda oom
result_list = self.get_bboxes(cls_score, bbox_pred, dir_cls_preds,
input_metas)
assert len(result_list[0][1]) > 0 # ensure not all boxes are filtered
assert (result_list[0][1] > 0.3).all()