Commit a6b97201 authored by zhangwenwei

Merge branch 'feature_semantic_head' into 'master'

semantic head with unittest

See merge request open-mmlab/mmdet.3d!13
parents ff17dc45 74faed7f
@@ -190,3 +190,23 @@ def rotation_2d(points, angles):
    rot_cos = torch.cos(angles)
    rot_mat_T = torch.stack([[rot_cos, -rot_sin], [rot_sin, rot_cos]])
    # points: (N, P, 2), rot_mat_T: (2, 2, N) -> rotated points: (N, P, 2)
    return torch.einsum('aij,jka->aik', points, rot_mat_T)


def enlarge_box3d_lidar(boxes3d, extra_width):
    """Enlarge the length, width and height of input boxes.

    Args:
        boxes3d (torch.Tensor or numpy.ndarray): bottom centers with
            shape [N, 7], (x, y, z, w, l, h, ry) in LiDAR coords.
        extra_width (float): a fixed width added to each side.

    Returns:
        torch.Tensor or numpy.ndarray: enlarged boxes.
    """
    if isinstance(boxes3d, np.ndarray):
        large_boxes3d = boxes3d.copy()
    else:
        large_boxes3d = boxes3d.clone()
    large_boxes3d[:, 3:6] += extra_width * 2
    large_boxes3d[:, 2] -= extra_width  # bottom center z minus extra_width
    return large_boxes3d
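For reference, a minimal usage sketch of enlarge_box3d_lidar (illustrative only, with made-up box values; not part of this commit):

import numpy as np

# Enlarge a single LiDAR box by 0.2m on every side.
boxes = np.array([[6.4, -3.4, -1.7, 1.7, 3.5, 1.6, -0.9]], dtype=np.float32)
large = enlarge_box3d_lidar(boxes, extra_width=0.2)
print(large[0, 3:6])  # (w, l, h) grow by 2 * extra_width -> [2.1 3.9 2.0]
print(large[0, 2])    # bottom-center z drops by extra_width -> -1.9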
from mmdet.models.losses import FocalLoss, SmoothL1Loss, binary_cross_entropy

__all__ = ['FocalLoss', 'SmoothL1Loss', 'binary_cross_entropy']
from .mask_heads import PointwiseSemanticHead
__all__ = ['PointwiseSemanticHead']
from .pointwise_semantic_head import PointwiseSemanticHead
__all__ = ['PointwiseSemanticHead']
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet3d.core import multi_apply
from mmdet3d.core.bbox import box_torch_ops
from mmdet3d.models.builder import build_loss
from mmdet3d.ops.roiaware_pool3d import points_in_boxes_gpu
from mmdet.models import HEADS
@HEADS.register_module
class PointwiseSemanticHead(nn.Module):
"""Semantic segmentation head for point-wise segmentation.
Predict point-wise segmentation and part regression results for PartA2.
See https://arxiv.org/abs/1907.03670 for more detials.
Args:
in_channels (int): the number of input channel.
num_classes (int): the number of class.
extra_width (float): boxes enlarge width.
loss_seg (dict): Config of segmentation loss.
loss_part (dict): Config of part prediction loss.
"""
def __init__(self,
in_channels,
num_classes=3,
extra_width=0.2,
seg_score_thr=0.3,
loss_seg=dict(
type='FocalLoss',
use_sigmoid=True,
reduction='sum',
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_part=dict(
type='CrossEntropyLoss',
use_sigmoid=True,
loss_weight=1.0)):
super(PointwiseSemanticHead, self).__init__()
self.extra_width = extra_width
self.num_classes = num_classes
self.seg_score_thr = seg_score_thr
self.seg_cls_layer = nn.Linear(in_channels, 1, bias=True)
self.seg_reg_layer = nn.Linear(in_channels, 3, bias=True)
self.loss_seg = build_loss(loss_seg)
self.loss_part = build_loss(loss_part)
def forward(self, x):
seg_preds = self.seg_cls_layer(x) # (N, 1)
part_preds = self.seg_reg_layer(x) # (N, 3)
seg_scores = torch.sigmoid(seg_preds).detach()
seg_mask = (seg_scores > self.seg_score_thr)
part_offsets = torch.sigmoid(part_preds).clone().detach()
part_offsets[seg_mask.view(-1) == 0] = 0
part_feats = torch.cat((part_offsets, seg_scores),
dim=-1) # shape (npoints, 4)
return dict(
seg_preds=seg_preds, part_preds=part_preds, part_feats=part_feats)
def get_targets_single(self, voxel_centers, gt_bboxes_3d, gt_labels_3d):
"""generate segmentation and part prediction targets
Args:
voxel_centers (torch.Tensor): shape [voxel_num, 3],
the center of voxels
gt_bboxes_3d (torch.Tensor): shape [box_num, 7], gt boxes
gt_labels_3d (torch.Tensor): shape [box_num], class label of gt
Returns:
tuple : segmentation targets with shape [voxel_num]
part prediction targets with shape [voxel_num, 3]
"""
enlarged_gt_boxes = box_torch_ops.enlarge_box3d_lidar(
gt_bboxes_3d, extra_width=self.extra_width)
part_targets = voxel_centers.new_zeros((voxel_centers.shape[0], 3),
dtype=torch.float32)
box_idx = points_in_boxes_gpu(
voxel_centers.unsqueeze(0),
gt_bboxes_3d.unsqueeze(0)).squeeze(0) # -1 ~ box_num
enlarge_box_idx = points_in_boxes_gpu(
voxel_centers.unsqueeze(0),
enlarged_gt_boxes.unsqueeze(0)).squeeze(0).long() # -1 ~ box_num
gt_labels_pad = F.pad(
gt_labels_3d, (1, 0), mode='constant', value=self.num_classes)
seg_targets = gt_labels_pad[(box_idx.long() + 1)]
fg_pt_flag = box_idx > -1
ignore_flag = fg_pt_flag ^ (enlarge_box_idx > -1)
seg_targets[ignore_flag] = -1
for k in range(gt_bboxes_3d.shape[0]):
k_box_flag = box_idx == k
            # no points in the current box (e.g. with a reduced point cloud)
if not k_box_flag.any():
continue
fg_voxels = voxel_centers[k_box_flag]
transformed_voxels = fg_voxels - gt_bboxes_3d[k, 0:3]
transformed_voxels = box_torch_ops.rotation_3d_in_axis(
transformed_voxels.unsqueeze(0),
-gt_bboxes_3d[k, 6].view(1),
axis=2)
part_targets[k_box_flag] = transformed_voxels / gt_bboxes_3d[
k, 3:6] + voxel_centers.new_tensor([0.5, 0.5, 0])
part_targets = torch.clamp(part_targets, min=0)
return seg_targets, part_targets
def get_targets(self, voxels_dict, gt_bboxes_3d, gt_labels_3d):
batch_size = len(gt_labels_3d)
voxel_center_list = []
for idx in range(batch_size):
coords_idx = voxels_dict['coors'][:, 0] == idx
voxel_center_list.append(voxels_dict['voxel_centers'][coords_idx])
seg_targets, part_targets = multi_apply(self.get_targets_single,
voxel_center_list,
gt_bboxes_3d, gt_labels_3d)
seg_targets = torch.cat(seg_targets, dim=0)
part_targets = torch.cat(part_targets, dim=0)
return dict(seg_targets=seg_targets, part_targets=part_targets)
def loss(self, seg_preds, part_preds, seg_targets, part_targets):
"""Calculate point-wise segmentation and part prediction losses.
Args:
seg_preds (torch.Tensor): prediction of binary
segmentation with shape [voxel_num, 1].
part_preds (torch.Tensor): prediction of part
with shape [voxel_num, 3].
seg_targets (torch.Tensor): target of segmentation
with shape [voxel_num, 1].
part_targets (torch.Tensor): target of part with
shape [voxel_num, 3].
Returns:
dict: loss of segmentation and part prediction.
"""
pos_mask = (seg_targets > -1) & (seg_targets < self.num_classes)
binary_seg_target = pos_mask.long()
pos = pos_mask.float()
neg = (seg_targets == self.num_classes).float()
seg_weights = pos + neg
pos_normalizer = pos.sum()
seg_weights = seg_weights / torch.clamp(pos_normalizer, min=1.0)
loss_seg = self.loss_seg(seg_preds, binary_seg_target, seg_weights)
if pos_normalizer > 0:
loss_part = self.loss_part(part_preds[pos_mask],
part_targets[pos_mask])
        else:
            # no foreground points: fall back to a zero part loss
            loss_part = loss_seg.new_tensor(0)
return dict(loss_seg=loss_seg, loss_part=loss_part)
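The part targets built in get_targets_single encode each foreground point's location inside its box in canonical coordinates, normalized by the box dimensions so that (0.5, 0.5, 0) is the bottom center. A minimal CPU-only sketch of the encoding (assuming an axis-aligned box, ry = 0, so the rotation step is a no-op):

import torch

# Hypothetical box (x, y, z, w, l, h, ry); z is the bottom center, ry = 0.
box = torch.tensor([2.0, 3.0, -1.0, 1.6, 3.9, 1.5, 0.0])
center = box[0:3] + torch.tensor([0.0, 0.0, box[5].item() / 2])  # geometric center
offset = center - box[0:3]  # rotation by -ry omitted since ry = 0
part_target = offset / box[3:6] + torch.tensor([0.5, 0.5, 0.0])
print(part_target)  # tensor([0.5000, 0.5000, 0.5000])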
@@ -6,8 +6,8 @@ from mmdet3d.ops.roiaware_pool3d import (RoIAwarePool3d, points_in_boxes_cpu,
def test_RoIAwarePool3d():
    # RoIAwarePool3d only supports the GPU version currently.
    if not torch.cuda.is_available():
        pytest.skip('test requires GPU and torch+cuda')
    roiaware_pool3d_max = RoIAwarePool3d(
        out_size=4, max_pts_per_voxel=128, mode='max')
...
import pytest
import torch
def test_PointwiseSemanticHead():
# PointwiseSemanticHead only support gpu version currently.
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
from mmdet3d.models.builder import build_head
head_cfg = dict(
type='PointwiseSemanticHead',
in_channels=8,
extra_width=0.2,
seg_score_thr=0.3,
num_classes=3,
loss_seg=dict(
type='FocalLoss',
use_sigmoid=True,
reduction='sum',
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_part=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0))
self = build_head(head_cfg)
self.cuda()
# test forward
voxel_features = torch.rand([4, 8], dtype=torch.float32).cuda()
feats_dict = self.forward(voxel_features)
assert feats_dict['seg_preds'].shape == torch.Size(
[voxel_features.shape[0], 1])
assert feats_dict['part_preds'].shape == torch.Size(
[voxel_features.shape[0], 3])
assert feats_dict['part_feats'].shape == torch.Size(
[voxel_features.shape[0], 4])
voxel_centers = torch.tensor(
[[6.56126, 0.9648336, -1.7339306], [6.8162713, -2.480431, -1.3616394],
[11.643568, -4.744306, -1.3580885], [23.482342, 6.5036807, 0.5806964]
],
        dtype=torch.float32).cuda()  # (N, 3) voxel centers
coordinates = torch.tensor(
[[0, 12, 819, 131], [0, 16, 750, 136], [1, 16, 705, 232],
[1, 35, 930, 469]],
        dtype=torch.int32).cuda()  # (N, 4): (batch, ind_x, ind_y, ind_z)
voxel_dict = dict(voxel_centers=voxel_centers, coors=coordinates)
gt_bboxes = list(
torch.tensor(
[[[6.4118, -3.4305, -1.7291, 1.7033, 3.4693, 1.6197, -0.9091]],
[[16.9107, 9.7925, -1.9201, 1.6097, 3.2786, 1.5307, -2.4056]]],
dtype=torch.float32).cuda())
gt_labels = list(torch.tensor([[0], [1]], dtype=torch.int64).cuda())
# test get_targets
target_dict = self.get_targets(voxel_dict, gt_bboxes, gt_labels)
assert target_dict['seg_targets'].shape == torch.Size(
[voxel_features.shape[0]])
assert target_dict['part_targets'].shape == torch.Size(
[voxel_features.shape[0], 3])
# test loss
loss_dict = self.loss(feats_dict['seg_preds'], feats_dict['part_preds'],
target_dict['seg_targets'],
target_dict['part_targets'])
assert loss_dict['loss_seg'] > 0
assert loss_dict['loss_part'] == 0 # no points in gt_boxes
total_loss = loss_dict['loss_seg'] + loss_dict['loss_part']
total_loss.backward()
if __name__ == '__main__':
test_PointwiseSemanticHead()
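The loss_part == 0 assertion above holds because no voxel centers fall inside the gt boxes, so pos_normalizer is zero and loss() falls back to the dummy part loss. For the segmentation weighting itself: foreground and background voxels get weight 1, ignored voxels (seg_target == -1) get weight 0, and the weights are normalized by the foreground count clamped to at least 1. A standalone sketch with made-up targets:

import torch

num_classes = 3
# Two foreground voxels (labels 0 and 2), one ignored (-1), two background.
seg_targets = torch.tensor([0, 2, -1, num_classes, num_classes])
pos = ((seg_targets > -1) & (seg_targets < num_classes)).float()
neg = (seg_targets == num_classes).float()
seg_weights = (pos + neg) / torch.clamp(pos.sum(), min=1.0)
print(seg_weights)  # tensor([0.5000, 0.5000, 0.0000, 0.5000, 0.5000])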