Commit 52310ad9 authored by Shaoshuai Shi
Browse files

support CenterHead / CenterPoint (1stage), add its WOD config

parent 0ccbbaae
...@@ -4,6 +4,7 @@ from .anchor_head_template import AnchorHeadTemplate ...@@ -4,6 +4,7 @@ from .anchor_head_template import AnchorHeadTemplate
from .point_head_box import PointHeadBox from .point_head_box import PointHeadBox
from .point_head_simple import PointHeadSimple from .point_head_simple import PointHeadSimple
from .point_intra_part_head import PointIntraPartOffsetHead from .point_intra_part_head import PointIntraPartOffsetHead
from .center_head import CenterHead
__all__ = { __all__ = {
'AnchorHeadTemplate': AnchorHeadTemplate, 'AnchorHeadTemplate': AnchorHeadTemplate,
...@@ -12,4 +13,5 @@ __all__ = { ...@@ -12,4 +13,5 @@ __all__ = {
'PointHeadSimple': PointHeadSimple, 'PointHeadSimple': PointHeadSimple,
'PointHeadBox': PointHeadBox, 'PointHeadBox': PointHeadBox,
'AnchorHeadMulti': AnchorHeadMulti, 'AnchorHeadMulti': AnchorHeadMulti,
'CenterHead': CenterHead
} }
...@@ -7,6 +7,7 @@ from .second_net import SECONDNet ...@@ -7,6 +7,7 @@ from .second_net import SECONDNet
from .second_net_iou import SECONDNetIoU from .second_net_iou import SECONDNetIoU
from .caddn import CaDDN from .caddn import CaDDN
from .voxel_rcnn import VoxelRCNN from .voxel_rcnn import VoxelRCNN
from .centerpoint import CenterPoint
__all__ = { __all__ = {
'Detector3DTemplate': Detector3DTemplate, 'Detector3DTemplate': Detector3DTemplate,
...@@ -17,7 +18,8 @@ __all__ = { ...@@ -17,7 +18,8 @@ __all__ = {
'PointRCNN': PointRCNN, 'PointRCNN': PointRCNN,
'SECONDNetIoU': SECONDNetIoU, 'SECONDNetIoU': SECONDNetIoU,
'CaDDN': CaDDN, 'CaDDN': CaDDN,
'VoxelRCNN': VoxelRCNN 'VoxelRCNN': VoxelRCNN,
'CenterPoint': CenterPoint
} }
......
...@@ -132,7 +132,8 @@ class Detector3DTemplate(nn.Module): ...@@ -132,7 +132,8 @@ class Detector3DTemplate(nn.Module):
class_names=self.class_names, class_names=self.class_names,
grid_size=model_info_dict['grid_size'], grid_size=model_info_dict['grid_size'],
point_cloud_range=model_info_dict['point_cloud_range'], point_cloud_range=model_info_dict['point_cloud_range'],
predict_boxes_when_training=self.model_cfg.get('ROI_HEAD', False) predict_boxes_when_training=self.model_cfg.get('ROI_HEAD', False),
voxel_size=model_info_dict.get('voxel_size', False)
) )
model_info_dict['module_list'].append(dense_head_module) model_info_dict['module_list'].append(dense_head_module)
return dense_head_module, model_info_dict return dense_head_module, model_info_dict
......
...@@ -259,3 +259,128 @@ def compute_fg_mask(gt_boxes2d, shape, downsample_factor=1, device=torch.device( ...@@ -259,3 +259,128 @@ def compute_fg_mask(gt_boxes2d, shape, downsample_factor=1, device=torch.device(
fg_mask[b, v1:v2, u1:u2] = True fg_mask[b, v1:v2, u1:u2] = True
return fg_mask return fg_mask
def neg_loss_cornernet(pred, gt, mask=None):
    """
    Refer to https://github.com/tianweiy/CenterPoint.
    Modified focal loss. Exactly the same as CornerNet. Runs faster and costs a little bit more memory

    Elements with ``gt == 1`` are treated as positives; elements with ``gt < 1``
    are negatives whose contribution is scaled by ``(1 - gt) ** 4``.

    Args:
        pred: (batch x c x h x w), expected to hold values in [0, 1] because both
            ``log(pred)`` and ``log(1 - pred)`` are taken
        gt: (batch x c x h x w)
        mask: (batch x h x w), optional; broadcast over the channel axis and
            multiplied into both loss terms and into the positive count
    Returns:
        Scalar tensor: the negated sum of the positive and negative terms divided
        by the number of (unmasked) positives, or the negated negative term alone
        when there are no positives.
    """
    if torch.is_floating_point(pred):
        # log(0) is -inf and -inf * 0 is NaN, so a prediction saturated to exactly
        # 0 or 1 would turn the whole summed loss into NaN even at locations whose
        # indicator is zero. Clamp to the closest representable values strictly
        # inside (0, 1); non-saturated predictions are left unchanged.
        limits = torch.finfo(pred.dtype)
        pred = pred.clamp(min=limits.tiny, max=1 - limits.eps)

    pos_inds = gt.eq(1).float()
    neg_inds = gt.lt(1).float()

    # Negatives close to a positive peak (gt near 1) are penalised less.
    neg_weights = torch.pow(1 - gt, 4)

    pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
    neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds

    if mask is not None:
        # (batch x h x w) -> (batch x 1 x h x w) so it broadcasts over channels.
        mask = mask[:, None, :, :].float()
        pos_loss = pos_loss * mask
        neg_loss = neg_loss * mask
        num_pos = (pos_inds * mask).sum()
    else:
        num_pos = pos_inds.sum()

    pos_loss = pos_loss.sum()
    neg_loss = neg_loss.sum()

    if num_pos == 0:
        # No positives: pos_loss is zero, and dividing by num_pos would be 0 / 0.
        return 0 - neg_loss
    return 0 - (pos_loss + neg_loss) / num_pos
class FocalLossCenterNet(nn.Module):
    """
    Module wrapper around ``neg_loss_cornernet``.

    Refer to https://github.com/tianweiy/CenterPoint
    """

    def __init__(self):
        super().__init__()
        # Kept as an attribute so the loss function is reachable as ``self.neg_loss``.
        self.neg_loss = neg_loss_cornernet

    def forward(self, out, target, mask=None):
        """Return ``self.neg_loss(out, target, mask=mask)``."""
        loss = self.neg_loss(out, target, mask=mask)
        return loss
def _reg_loss(regr, gt_regr, mask):
    """
    Refer to https://github.com/tianweiy/CenterPoint
    L1 regression loss

    Entries are excluded from the loss when their object is masked out or when
    the target value is NaN. The normaliser is the number of set mask entries
    (NaN targets do not reduce it), clamped to at least 1.

    Args:
        regr (batch x max_objects x dim)
        gt_regr (batch x max_objects x dim)
        mask (batch x max_objects)
    Returns:
        Tensor of shape (dim,): absolute error summed over batch and objects for
        each regression channel, divided by the clamped mask count.
    """
    num = mask.float().sum()

    valid = ~torch.isnan(gt_regr)
    # Build the weight as a fresh tensor. Multiplying in place into the
    # ``expand_as`` view fails when ``mask`` is already float, because ``.float()``
    # then returns the overlapping view itself.
    weight = mask.unsqueeze(2).expand_as(gt_regr).float() * valid.float()

    # NaN * 0 is still NaN, so NaN targets must be replaced before weighting;
    # otherwise they leak into the summed loss despite the zero weight.
    gt_regr = torch.where(valid, gt_regr, torch.zeros_like(gt_regr))

    loss = torch.abs(regr * weight - gt_regr * weight)
    # Reduce over batch first, then over objects, leaving one value per channel.
    loss = loss.sum(dim=0).sum(dim=0)

    return loss / torch.clamp_min(num, min=1.0)
def _gather_feat(feat, ind, mask=None):
    """
    Pick rows of ``feat`` along dimension 1 using per-batch indices.

    Args:
        feat: 3-D tensor; its last dimension size is read via ``feat.size(2)``
        ind: 2-D index tensor, expanded across the last dimension of ``feat``
        mask: optional 2-D boolean selector over the gathered rows
    Returns:
        The gathered tensor of shape (ind.size(0), ind.size(1), feat.size(2)),
        or, when ``mask`` is given, only the selected rows reshaped to
        (-1, feat.size(2)).
    """
    num_channels = feat.size(2)
    # Repeat each index across the channel axis so gather copies whole rows.
    gather_index = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), num_channels)
    gathered = feat.gather(1, gather_index)

    if mask is None:
        return gathered

    keep = mask.unsqueeze(2).expand_as(gathered)
    return gathered[keep].view(-1, num_channels)
def _transpose_and_gather_feat(feat, ind):
    """
    Move dimension 1 of a 4-D tensor to the end, flatten dimensions 1-2 and
    gather rows of the flattened tensor at ``ind`` via ``_gather_feat``.

    Args:
        feat: 4-D tensor, permuted here with ``(0, 2, 3, 1)``
        ind: index tensor forwarded unchanged to ``_gather_feat``
    Returns:
        Whatever ``_gather_feat`` returns for the flattened tensor and ``ind``.
    """
    channels_last = feat.permute(0, 2, 3, 1).contiguous()
    batch_size = channels_last.size(0)
    num_channels = channels_last.size(3)
    flattened = channels_last.view(batch_size, -1, num_channels)
    return _gather_feat(flattened, ind)
class RegLossCenterNet(nn.Module):
    """
    Module wrapper around ``_reg_loss`` that can first gather predictions at
    given indices.

    Refer to https://github.com/tianweiy/CenterPoint
    """

    def __init__(self):
        super().__init__()

    def forward(self, output, mask, ind=None, target=None):
        """
        Args:
            output: (batch x dim x h x w) or (batch x max_objects)
            mask: (batch x max_objects)
            ind: (batch x max_objects); when None, ``output`` is used as the
                prediction directly instead of being gathered
            target: (batch x max_objects x dim)
        Returns:
            The result of ``_reg_loss(pred, target, mask)``.
        """
        pred = output if ind is None else _transpose_and_gather_feat(output, ind)
        return _reg_loss(pred, target, mask)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment