Commit c42ad958 authored by zhangwenwei

Merge branch 'voting_module' into 'master'

Voting module

See merge request open-mmlab/mmdet.3d!45
parents b2db1753 96c2c4e8
@@ -295,6 +295,11 @@ def indoor_eval(gt_annos, dt_annos, metric, label2cat):
                                       'constant')
            gt_infos.append(
                dict(boxes_3d=bbox_lidar_bottom, labels_3d=gt_anno['class']))
+        else:
+            gt_infos.append(
+                dict(
+                    boxes_3d=np.array([], dtype=np.float32),
+                    labels_3d=np.array([], dtype=np.int64)))
    result_str = str()
    result_str += 'mAP'
@@ -303,12 +308,13 @@ def indoor_eval(gt_annos, dt_annos, metric, label2cat):
    for i, iou_thresh in enumerate(metric):
        rec_list = []
        for label in ap[i].keys():
-            ret_dict[f'{label2cat[label]}_AP_{iou_thresh:.2f}'] = ap[i][label][
-                0]
-        ret_dict[f'mAP_{iou_thresh:.2f}'] = np.mean(list(ap[i].values()))
+            ret_dict[f'{label2cat[label]}_AP_{iou_thresh:.2f}'] = float(
+                ap[i][label][0])
+        ret_dict[f'mAP_{iou_thresh:.2f}'] = float(
+            np.mean(list(ap[i].values())))
        for label in rec[i].keys():
-            ret_dict[f'{label2cat[label]}_rec_{iou_thresh:.2f}'] = rec[i][
-                label][-1]
+            ret_dict[f'{label2cat[label]}_rec_{iou_thresh:.2f}'] = float(
+                rec[i][label][-1])
            rec_list.append(rec[i][label][-1])
-        ret_dict[f'mAR_{iou_thresh:.2f}'] = np.mean(rec_list)
+        ret_dict[f'mAR_{iou_thresh:.2f}'] = float(np.mean(rec_list))
    return ret_dict
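A note on the float() casts introduced above: values pulled out of numpy arrays are numpy scalars, which are awkward to log or serialize. A minimal, self-contained sketch of the difference (the metric name and value below are made up for illustration):

import json

import numpy as np

np_val = np.float32(0.5)  # what indexing a numpy array returns
try:
    json.dumps({'mAP_0.25': np_val})
except TypeError as err:
    print('numpy scalar is rejected:', err)
print(json.dumps({'mAP_0.25': float(np_val)}))  # {"mAP_0.25": 0.5}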
@@ -3,6 +3,7 @@ import tempfile
import mmcv
import numpy as np
+from mmcv.utils import print_log
from torch.utils.data import Dataset

from mmdet.datasets import DATASETS
@@ -114,26 +115,33 @@ class Custom3DDataset(Dataset):
        mmcv.dump(outputs, out)
        return outputs, tmp_dir

-    def evaluate(self, results, metric=None):
+    def evaluate(self, results, metric=None, iou_thr=(0.25, 0.5), logger=None):
        """Evaluate.

        Evaluation in indoor protocol.

        Args:
            results (list[dict]): List of results.
-            metric (list[float]): AP IoU thresholds.
+            metric (str | list[str]): Metrics to be evaluated.
+            iou_thr (list[float]): AP IoU thresholds.
+            logger (logging.Logger | None): Logger used for printing
+                evaluation results. Default: None.
        """
        from mmdet3d.core.evaluation import indoor_eval
        assert isinstance(
            results, list), f'Expect results to be list, got {type(results)}.'
        assert len(results) > 0, f'Expect length of results > 0.'
        assert isinstance(
            results[0], dict
        ), f'Expect elements in results to be dict, got {type(results[0])}.'
        assert len(metric) > 0, f'Expect length of metric > 0.'
        gt_annos = [info['annos'] for info in self.data_infos]
        label2cat = {i: cat_id for i, cat_id in enumerate(self.CLASSES)}
-        ret_dict = indoor_eval(gt_annos, results, metric, label2cat)
+        ret_dict = indoor_eval(gt_annos, results, iou_thr, label2cat)
        result_str = str()
        for key, val in ret_dict.items():
            result_str += f'{key} : {val} \n'
+        mAP_25, mAP_50 = ret_dict['mAP_0.25'], ret_dict['mAP_0.50']
+        result_str += f'mAP(0.25): {mAP_25} mAP(0.50): {mAP_50}'
+        print_log('\n' + result_str, logger=logger)
        return ret_dict

    def __len__(self):
......
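For illustration only (not part of the diff): a runnable sketch of how label2cat above and the key formatting inside indoor_eval fit together. The class names and AP values are made up.

import numpy as np

CLASSES = ('cabinet', 'bed')
label2cat = {i: cat_id for i, cat_id in enumerate(CLASSES)}

# fake per-class AP at one IoU threshold, mimicking ap[i][label][0]
ap = {0: np.array([0.5]), 1: np.array([0.25])}
ret_dict = {}
iou_thresh = 0.25
for label in ap.keys():
    ret_dict[f'{label2cat[label]}_AP_{iou_thresh:.2f}'] = float(ap[label][0])
ret_dict[f'mAP_{iou_thresh:.2f}'] = float(np.mean(list(ap.values())))
print(ret_dict)
# {'cabinet_AP_0.25': 0.5, 'bed_AP_0.25': 0.25, 'mAP_0.25': 0.375}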
from mmdet.models.backbones import SSDVGG, HRNet, ResNet, ResNetV1d, ResNeXt
-from .pointnet2_sa import PointNet2SA
+from .pointnet2_sa_ssg import PointNet2SASSG
from .second import SECOND

__all__ = [
    'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'SECOND',
-    'PointNet2SA'
+    'PointNet2SASSG'
]
@@ -7,9 +7,8 @@ from mmdet.models import BACKBONES
@BACKBONES.register_module()
-class PointNet2SA(nn.Module):
-    """PointNet2 using set abstraction (SA) and feature propagation (FP)
-    modules.
+class PointNet2SASSG(nn.Module):
+    """PointNet2 with Single-scale grouping.

    Args:
        in_channels (int): input channels of point cloud.
......
@@ -13,6 +13,7 @@ from .roiaware_pool3d import (RoIAwarePool3d, points_in_boxes_cpu,
                              points_in_boxes_gpu)
from .sparse_block import (SparseBasicBlock, SparseBottleneck,
                           make_sparse_convmodule)
+from .vote_module import VoteModule
from .voxel import DynamicScatter, Voxelization, dynamic_scatter, voxelization

__all__ = [
@@ -25,5 +26,5 @@ __all__ = [
    'make_sparse_convmodule', 'ball_query', 'furthest_point_sample',
    'three_interpolate', 'three_nn', 'gather_points', 'grouping_operation',
    'group_points', 'GroupAll', 'QueryAndGroup', 'PointSAModule',
-    'PointSAModuleMSG', 'PointFPModule'
+    'PointSAModuleMSG', 'PointFPModule', 'VoteModule'
]
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from torch.nn.functional import l1_loss, mse_loss, smooth_l1_loss


class VoteModule(nn.Module):
    """Vote module.

    Generate votes from seed point features.

    Args:
        in_channels (int): Number of channels of seed point features.
        vote_per_seed (int): Number of votes generated from each seed point.
        gt_per_seed (int): Number of ground truth votes generated
            from each seed point.
        conv_channels (tuple[int]): Out channels of vote
            generating convolution.
        conv_cfg (dict): Config of convolution.
            Default: dict(type='Conv1d').
        norm_cfg (dict): Config of normalization.
            Default: dict(type='BN1d').
        norm_feats (bool): Whether to normalize features.
            Default: True.
        loss_weight (float): Weight of voting loss.
    """

    def __init__(self,
                 in_channels,
                 vote_per_seed=1,
                 gt_per_seed=3,
                 conv_channels=(16, 16),
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 norm_feats=True,
                 loss_weight=1.0):
        super().__init__()
        self.in_channels = in_channels
        self.vote_per_seed = vote_per_seed
        self.gt_per_seed = gt_per_seed
        self.norm_feats = norm_feats
        self.loss_weight = loss_weight

        prev_channels = in_channels
        vote_conv_list = list()
        for k in range(len(conv_channels)):
            vote_conv_list.append(
                ConvModule(
                    prev_channels,
                    conv_channels[k],
                    1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    bias=True,
                    inplace=True))
            prev_channels = conv_channels[k]
        self.vote_conv = nn.Sequential(*vote_conv_list)

        # conv_out predicts coordinate and residual features
        out_channel = (3 + in_channels) * self.vote_per_seed
        self.conv_out = nn.Conv1d(prev_channels, out_channel, 1)
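        # Shape note (illustrative, not in the source): with the values used
        # in the unit test at the bottom of this page, in_channels=8 and
        # vote_per_seed=3, conv_out therefore emits (3 + 8) * 3 = 33 channels
        # per seed: 3 xyz offsets plus 8 feature residuals for each vote.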
    def forward(self, seed_points, seed_feats):
        """Forward.

        Args:
            seed_points (Tensor): (B, N, 3) coordinate of the seed points.
            seed_feats (Tensor): (B, C, N) features of the seed points.

        Returns:
            tuple[Tensor]:

                - vote_points: Voted xyz based on the seed points
                  with shape (B, M, 3), where M=num_seed*vote_per_seed.
                - vote_features: Voted features based on the seed points with
                  shape (B, C, M), where M=num_seed*vote_per_seed,
                  C=vote_feature_dim.
        """
        batch_size, feat_channels, num_seed = seed_feats.shape
        num_vote = num_seed * self.vote_per_seed
        x = self.vote_conv(seed_feats)
        # (batch_size, (3+out_dim)*vote_per_seed, num_seed)
        votes = self.conv_out(x)

        votes = votes.transpose(2, 1).view(batch_size, num_seed,
                                           self.vote_per_seed, -1)
        offset = votes[:, :, :, 0:3]
        res_feats = votes[:, :, :, 3:]

        vote_points = (seed_points.unsqueeze(2) + offset).contiguous()
        vote_points = vote_points.view(batch_size, num_vote, 3)
        vote_feats = (seed_feats.transpose(2, 1).unsqueeze(2) +
                      res_feats).contiguous()
        vote_feats = vote_feats.view(
            batch_size, num_vote, feat_channels).transpose(2, 1).contiguous()

        if self.norm_feats:
            features_norm = torch.norm(vote_feats, p=2, dim=1)
            vote_feats = vote_feats.div(features_norm.unsqueeze(1))
        return vote_points, vote_feats
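    # Worked shape example for forward() above (illustrative, not in the
    # source): with seed_points of shape (2, 64, 3), seed_feats of shape
    # (2, 8, 64) and vote_per_seed=3, conv_out produces votes of shape
    # (2, 33, 64), reshaped to (2, 64, 3, 11) and split into xyz offsets
    # (2, 64, 3, 3) and feature residuals (2, 64, 3, 8); the returned
    # vote_points and vote_feats end up as (2, 192, 3) and (2, 8, 192).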
    def get_loss(self, seed_points, vote_points, seed_indices,
                 vote_targets_mask, vote_targets):
        """Calculate loss of voting module.

        Args:
            seed_points (Tensor): coordinate of the seed points.
            vote_points (Tensor): coordinate of the vote points.
            seed_indices (Tensor): indices of seed points in raw points.
            vote_targets_mask (Tensor): mask of valid vote targets.
            vote_targets (Tensor): targets of votes.

        Returns:
            Tensor: weighted vote loss.
        """
        batch_size, num_seed = seed_points.shape[:2]

        seed_gt_votes_mask = torch.gather(vote_targets_mask, 1,
                                          seed_indices).float()
        pos_num = torch.sum(seed_gt_votes_mask)

        seed_indices_expand = seed_indices.unsqueeze(-1).repeat(
            1, 1, 3 * self.gt_per_seed)
        seed_gt_votes = torch.gather(vote_targets, 1, seed_indices_expand)
        seed_gt_votes += seed_points.repeat(1, 1, 3)

        distance = self.nn_distance(
            vote_points.view(batch_size * num_seed, -1, 3),
            seed_gt_votes.view(batch_size * num_seed, -1, 3),
            mode='l1')[2]
        votes_distance = torch.min(distance, dim=1)[0]
        votes_dist = votes_distance.view(batch_size, num_seed)
        vote_loss = torch.sum(votes_dist * seed_gt_votes_mask) / (
            pos_num + 1e-6)
        return self.loss_weight * vote_loss
    def nn_distance(self, points1, points2, mode='smooth_l1'):
        """Find the nearest neighbors between points1 and points2.

        Args:
            points1 (Tensor): (B, N, 3) first set of points.
            points2 (Tensor): (B, M, 3) second set of points.
            mode (str): Specify the function (smooth_l1, l1 or l2)
                to calculate distance.

        Returns:
            tuple[Tensor]:

                - distance1: the nearest distance from points1 to points2.
                - index1: the index of the nearest neighbor for points1.
                - distance2: the nearest distance from points2 to points1.
                - index2: the index of the nearest neighbor for points2.
        """
        assert mode in ['smooth_l1', 'l1', 'l2']
        N = points1.shape[1]
        M = points2.shape[1]
        pc1_expand_tile = points1.unsqueeze(2).repeat(1, 1, M, 1)
        pc2_expand_tile = points2.unsqueeze(1).repeat(1, N, 1, 1)

        # use reduction='none' to keep the per-pair distances
        if mode == 'smooth_l1':
            pc_dist = torch.sum(
                smooth_l1_loss(
                    pc1_expand_tile, pc2_expand_tile, reduction='none'),
                dim=-1)  # (B, N, M)
        elif mode == 'l1':
            pc_dist = torch.sum(
                l1_loss(pc1_expand_tile, pc2_expand_tile, reduction='none'),
                dim=-1)  # (B, N, M)
        elif mode == 'l2':
            pc_dist = torch.sum(
                mse_loss(pc1_expand_tile, pc2_expand_tile, reduction='none'),
                dim=-1)  # (B, N, M)
        else:
            raise NotImplementedError

        distance1, index1 = torch.min(pc_dist, dim=2)  # (B, N)
        distance2, index2 = torch.min(pc_dist, dim=1)  # (B, M)
        return distance1, index1, distance2, index2
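For illustration only (not part of this merge request): the unit test further down only exercises forward(), so here is a hedged, self-contained sketch of how get_loss might be driven with random dummy targets. Every tensor size and value below is made up.

import torch

from mmdet3d.ops import VoteModule

vote_module = VoteModule(vote_per_seed=1, in_channels=8, gt_per_seed=3)
num_points, num_seed = 256, 64
seed_points = torch.rand(2, num_seed, 3)
seed_feats = torch.rand(2, 8, num_seed)
# indices of the seed points in the raw point cloud
seed_indices = torch.randint(0, num_points, (2, num_seed))
# 3 ground-truth vote offsets (9 values) and a validity mask per raw point
vote_targets = torch.rand(2, num_points, 3 * 3)
vote_targets_mask = torch.randint(0, 2, (2, num_points)).float()

vote_points, vote_feats = vote_module(seed_points, seed_feats)
loss = vote_module.get_loss(seed_points, vote_points, seed_indices,
                            vote_targets_mask, vote_targets)
print(loss)  # a single weighted scalar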
@@ -5,12 +5,12 @@ import torch
from mmdet3d.models import build_backbone


-def test_pointnet2_sa():
+def test_pointnet2_sa_ssg():
    if not torch.cuda.is_available():
        pytest.skip()
    cfg = dict(
-        type='PointNet2SA',
+        type='PointNet2SASSG',
        in_channels=6,
        num_points=(32, 16),
        radius=(0.8, 1.2),
......
import torch


def test_voting_module():
    from mmdet3d.ops import VoteModule

    self = VoteModule(vote_per_seed=3, in_channels=8)
    seed_xyz = torch.rand([2, 64, 3], dtype=torch.float32)  # (b, npoints, 3)
    seed_features = torch.rand(
        [2, 8, 64], dtype=torch.float32)  # (b, in_channels, npoints)

    # test forward
    vote_xyz, vote_features = self(seed_xyz, seed_features)
    assert vote_xyz.shape == torch.Size([2, 192, 3])
    assert vote_features.shape == torch.Size([2, 8, 192])
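A possible follow-up check, for illustration only (not part of the diff): with the default norm_feats=True, each vote's feature vector should come back L2-normalized over the channel dimension.

import torch

from mmdet3d.ops import VoteModule

vote_module = VoteModule(vote_per_seed=3, in_channels=8)
seed_xyz = torch.rand([2, 64, 3], dtype=torch.float32)
seed_features = torch.rand([2, 8, 64], dtype=torch.float32)
_, vote_features = vote_module(seed_xyz, seed_features)
# the norm over the channel dimension should be ~1 for every vote
assert torch.allclose(
    vote_features.norm(p=2, dim=1), torch.ones(2, 192), atol=1e-5)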