Commit b64375c8 authored by wuyuefeng's avatar wuyuefeng Committed by zhangwenwei
Browse files

Feat pointnet2 modules

parent fb2120b9
...@@ -4,9 +4,11 @@ from mmdet.ops import (RoIAlign, SigmoidFocalLoss, get_compiler_version, ...@@ -4,9 +4,11 @@ from mmdet.ops import (RoIAlign, SigmoidFocalLoss, get_compiler_version,
from .ball_query import ball_query from .ball_query import ball_query
from .furthest_point_sample import furthest_point_sample from .furthest_point_sample import furthest_point_sample
from .gather_points import gather_points from .gather_points import gather_points
from .group_points import group_points, grouping_operation from .group_points import (GroupAll, QueryAndGroup, group_points,
grouping_operation)
from .interpolate import three_interpolate, three_nn from .interpolate import three_interpolate, three_nn
from .norm import NaiveSyncBatchNorm1d, NaiveSyncBatchNorm2d from .norm import NaiveSyncBatchNorm1d, NaiveSyncBatchNorm2d
from .pointnet_modules import PointFPModule, PointSAModule, PointSAModuleMSG
from .roiaware_pool3d import (RoIAwarePool3d, points_in_boxes_cpu, from .roiaware_pool3d import (RoIAwarePool3d, points_in_boxes_cpu,
points_in_boxes_gpu) points_in_boxes_gpu)
from .sparse_block import (SparseBasicBlock, SparseBottleneck, from .sparse_block import (SparseBasicBlock, SparseBottleneck,
...@@ -22,5 +24,6 @@ __all__ = [ ...@@ -22,5 +24,6 @@ __all__ = [
'RoIAwarePool3d', 'points_in_boxes_gpu', 'points_in_boxes_cpu', 'RoIAwarePool3d', 'points_in_boxes_gpu', 'points_in_boxes_cpu',
'make_sparse_convmodule', 'ball_query', 'furthest_point_sample', 'make_sparse_convmodule', 'ball_query', 'furthest_point_sample',
'three_interpolate', 'three_nn', 'gather_points', 'grouping_operation', 'three_interpolate', 'three_nn', 'gather_points', 'grouping_operation',
'group_points' 'group_points', 'GroupAll', 'QueryAndGroup', 'PointSAModule',
'PointSAModuleMSG', 'PointFPModule'
] ]
from .point_fp_module import PointFPModule
from .point_sa_module import PointSAModule, PointSAModuleMSG
__all__ = ['PointSAModuleMSG', 'PointSAModule', 'PointFPModule']
from typing import List
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmdet3d.ops import three_interpolate, three_nn
class PointFPModule(nn.Module):
    """Point feature propagation module used in PointNets.

    Propagates features from a source point set to a target point set by
    inverse-distance-weighted interpolation over the three nearest source
    points, followed by a shared MLP.

    Args:
        mlp_channels (list[int]): List of mlp channels.
        norm_cfg (dict, optional): Type of normalization method.
            Default: None, which is resolved to dict(type='BN2d').
    """

    def __init__(self,
                 mlp_channels: List[int],
                 norm_cfg: dict = None):
        super().__init__()
        # NOTE: a literal dict default would be a shared mutable default
        # argument; resolve None to the intended BN2d config instead.
        if norm_cfg is None:
            norm_cfg = dict(type='BN2d')
        self.mlps = nn.Sequential()
        for i in range(len(mlp_channels) - 1):
            self.mlps.add_module(
                f'layer{i}',
                ConvModule(
                    mlp_channels[i],
                    mlp_channels[i + 1],
                    kernel_size=(1, 1),
                    stride=(1, 1),
                    conv_cfg=dict(type='Conv2d'),
                    norm_cfg=norm_cfg))

    def forward(self, target: torch.Tensor, source: torch.Tensor,
                target_feats: torch.Tensor,
                source_feats: torch.Tensor) -> torch.Tensor:
        """forward.

        Args:
            target (Tensor): (B, n, 3) tensor of the xyz positions of
                the target features.
            source (Tensor): (B, m, 3) tensor of the xyz positions of
                the source features. May be None, in which case
                ``source_feats`` is broadcast to every target point.
            target_feats (Tensor): (B, C1, n) tensor of the features to be
                propagated to. May be None.
            source_feats (Tensor): (B, C2, m) tensor of features
                to be propagated.

        Return:
            Tensor: (B, M, N) M = mlp[-1], tensor of the target features.
        """
        if source is not None:
            # Three nearest source neighbors for each target point.
            dist, idx = three_nn(target, source)
            # Inverse-distance weights, normalized to sum to 1 per point;
            # the epsilon guards against division by zero at distance 0.
            dist_reciprocal = 1.0 / (dist + 1e-8)
            norm = torch.sum(dist_reciprocal, dim=2, keepdim=True)
            weight = dist_reciprocal / norm
            interpolated_feats = three_interpolate(source_feats, idx, weight)
        else:
            # No source coordinates: replicate the (global) source features
            # across all n target points.
            interpolated_feats = source_feats.expand(*source_feats.size()[0:2],
                                                     target.size(1))

        if target_feats is not None:
            new_features = torch.cat([interpolated_feats, target_feats],
                                     dim=1)  # (B, C2 + C1, n)
        else:
            new_features = interpolated_feats

        # Add a trailing singleton dim so the 1x1 Conv2d MLP applies,
        # then drop it again.
        new_features = new_features.unsqueeze(-1)
        new_features = self.mlps(new_features)

        return new_features.squeeze(-1)
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmdet3d.ops import (GroupAll, QueryAndGroup, furthest_point_sample,
gather_points)
class PointSAModuleMSG(nn.Module):
    """Point set abstraction module with multi-scale grouping used in
    Pointnets.

    Args:
        num_point (int): Number of points.
        radii (list[float]): List of radius in each ball query.
        sample_nums (list[int]): Number of samples in each ball query.
        mlp_channels (list[list[int]]): Specify of the pointnet before
            the global pooling for each scale.
        norm_cfg (dict, optional): Type of normalization method.
            Default: None, which is resolved to dict(type='BN2d').
        use_xyz (bool): Whether to use xyz.
            Default: True.
        pool_mod (str): Type of pooling method, 'max' or 'avg'.
            Default: 'max'.
        normalize_xyz (bool): Whether to normalize local XYZ with radius.
            Default: False.
    """

    def __init__(self,
                 num_point: int,
                 radii: List[float],
                 sample_nums: List[int],
                 mlp_channels: List[List[int]],
                 norm_cfg: dict = None,
                 use_xyz: bool = True,
                 pool_mod='max',
                 normalize_xyz: bool = False):
        super().__init__()

        assert len(radii) == len(sample_nums) == len(mlp_channels)
        assert pool_mod in ['max', 'avg']

        # NOTE: a literal dict default would be a shared mutable default
        # argument; resolve None to the intended BN2d config instead.
        if norm_cfg is None:
            norm_cfg = dict(type='BN2d')

        self.num_point = num_point
        self.pool_mod = pool_mod
        self.groupers = nn.ModuleList()
        self.mlps = nn.ModuleList()

        for i in range(len(radii)):
            radius = radii[i]
            sample_num = sample_nums[i]
            if num_point is not None:
                grouper = QueryAndGroup(
                    radius,
                    sample_num,
                    use_xyz=use_xyz,
                    normalize_xyz=normalize_xyz)
            else:
                grouper = GroupAll(use_xyz)
            self.groupers.append(grouper)

            # Copy the per-scale spec so the caller's mlp_channels list is
            # not mutated in place by the use_xyz adjustment below.
            mlp_spec = list(mlp_channels[i])
            if use_xyz:
                # Grouped features get xyz concatenated as 3 extra channels.
                mlp_spec[0] += 3

            mlp = nn.Sequential()
            # Use a distinct index so the outer scale index is not shadowed.
            for j in range(len(mlp_spec) - 1):
                mlp.add_module(
                    f'layer{j}',
                    ConvModule(
                        mlp_spec[j],
                        mlp_spec[j + 1],
                        kernel_size=(1, 1),
                        stride=(1, 1),
                        conv_cfg=dict(type='Conv2d'),
                        norm_cfg=norm_cfg))
            self.mlps.append(mlp)

    def forward(
        self,
        points_xyz: torch.Tensor,
        features: torch.Tensor = None,
        indices: torch.Tensor = None
    ) -> (torch.Tensor, torch.Tensor, torch.Tensor):
        """forward.

        Args:
            points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
            features (Tensor): (B, C, N) features of each point.
                Default: None.
            indices (Tensor): (B, num_point) Index of the features.
                Default: None; computed by furthest point sampling.

        Returns:
            Tensor: (B, M, 3) where M is the number of points.
                New features xyz.
            Tensor: (B, M, sum_k(mlps[k][-1])) where M is the number
                of points. New feature descriptors.
            Tensor: (B, M) where M is the number of points.
                Index of the features.
        """
        new_features_list = []
        xyz_flipped = points_xyz.transpose(1, 2).contiguous()
        if indices is None:
            indices = furthest_point_sample(points_xyz, self.num_point)
        else:
            assert (indices.shape[1] == self.num_point)

        new_xyz = gather_points(xyz_flipped, indices).transpose(
            1, 2).contiguous() if self.num_point is not None else None

        for i in range(len(self.groupers)):
            # (B, C, num_point, nsample)
            new_features = self.groupers[i](points_xyz, new_xyz, features)

            # (B, mlp[-1], num_point, nsample)
            new_features = self.mlps[i](new_features)

            # Pool over the nsample dimension for each query point.
            if self.pool_mod == 'max':
                # (B, mlp[-1], num_point, 1)
                new_features = F.max_pool2d(
                    new_features, kernel_size=[1, new_features.size(3)])
            elif self.pool_mod == 'avg':
                # (B, mlp[-1], num_point, 1)
                new_features = F.avg_pool2d(
                    new_features, kernel_size=[1, new_features.size(3)])
            else:
                raise NotImplementedError

            new_features = new_features.squeeze(-1)  # (B, mlp[-1], num_point)
            new_features_list.append(new_features)

        return new_xyz, torch.cat(new_features_list, dim=1), indices
class PointSAModule(PointSAModuleMSG):
    """Point set abstraction module used in Pointnets.

    Single-scale convenience wrapper around :class:`PointSAModuleMSG`.

    Args:
        mlp_channels (list[int]): Specify of the pointnet before
            the global pooling for each scale.
        num_point (int): Number of points.
            Default: None.
        radius (float): Radius to group with.
            Default: None.
        num_sample (int): Number of samples in each ball query.
            Default: None.
        norm_cfg (dict, optional): Type of normalization method.
            Default: None, which is resolved to dict(type='BN2d').
        use_xyz (bool): Whether to use xyz.
            Default: True.
        pool_mod (str): Type of pooling method, 'max' or 'avg'.
            Default: 'max'.
        normalize_xyz (bool): Whether to normalize local XYZ with radius.
            Default: False.
    """

    def __init__(self,
                 mlp_channels: List[int],
                 num_point: int = None,
                 radius: float = None,
                 num_sample: int = None,
                 norm_cfg: dict = None,
                 use_xyz: bool = True,
                 pool_mod: str = 'max',
                 normalize_xyz: bool = False):
        # NOTE: a literal dict default would be a shared mutable default
        # argument; resolve None here rather than relying on the parent.
        if norm_cfg is None:
            norm_cfg = dict(type='BN2d')
        # Wrap the single-scale arguments into the multi-scale lists
        # expected by PointSAModuleMSG.
        super().__init__(
            mlp_channels=[mlp_channels],
            num_point=num_point,
            radii=[radius],
            sample_nums=[num_sample],
            norm_cfg=norm_cfg,
            use_xyz=use_xyz,
            pool_mod=pool_mod,
            normalize_xyz=normalize_xyz)
import numpy as np
import torch
def test_pointnet_sa_module_msg():
    """Build a multi-scale SA module and verify its layers and forward."""
    from mmdet3d.ops import PointSAModuleMSG

    sa_msg = PointSAModuleMSG(
        num_point=16,
        radii=[0.2, 0.4],
        sample_nums=[4, 8],
        mlp_channels=[[12, 16], [12, 32]],
        norm_cfg=dict(type='BN2d'),
        use_xyz=False,
        pool_mod='max').cuda()

    # The first conv of each scale must match the configured channels.
    for scale, (c_in, c_out) in enumerate([(12, 16), (12, 32)]):
        first_conv = sa_msg.mlps[scale].layer0.conv
        assert first_conv.in_channels == c_in
        assert first_conv.out_channels == c_out

    points = np.load('tests/data/sunrgbd/sunrgbd_trainval/lidar/000001.npy')
    # (B, N, 3)
    xyz = torch.from_numpy(points[..., :3]).view(1, -1, 3).cuda()
    # (B, C, N)
    features = xyz.repeat([1, 1, 4]).transpose(1, 2).contiguous().cuda()

    # test forward
    new_xyz, new_features, inds = sa_msg(xyz, features)
    assert new_xyz.shape == torch.Size([1, 16, 3])
    assert new_features.shape == torch.Size([1, 48, 16])
    assert inds.shape == torch.Size([1, 16])
def test_pointnet_sa_module():
    """Build a single-scale SA module and verify its layers and forward."""
    from mmdet3d.ops import PointSAModule

    sa_module = PointSAModule(
        num_point=16,
        radius=0.2,
        num_sample=8,
        mlp_channels=[12, 32],
        norm_cfg=dict(type='BN2d'),
        use_xyz=True,
        pool_mod='max').cuda()

    # use_xyz adds 3 channels on top of the configured 12 inputs.
    first_conv = sa_module.mlps[0].layer0.conv
    assert first_conv.in_channels == 15
    assert first_conv.out_channels == 32

    points = np.load('tests/data/sunrgbd/sunrgbd_trainval/lidar/000001.npy')
    # (B, N, 3)
    xyz = torch.from_numpy(points[..., :3]).view(1, -1, 3).cuda()
    # (B, C, N)
    features = xyz.repeat([1, 1, 4]).transpose(1, 2).contiguous().cuda()

    # test forward
    new_xyz, new_features, inds = sa_module(xyz, features)
    assert new_xyz.shape == torch.Size([1, 16, 3])
    assert new_features.shape == torch.Size([1, 32, 16])
    assert inds.shape == torch.Size([1, 16])
def test_pointnet_fp_module():
    """Build a feature propagation module and verify its layers and forward."""
    from mmdet3d.ops import PointFPModule

    fp_module = PointFPModule(mlp_channels=[24, 16]).cuda()
    assert fp_module.mlps.layer0.conv.in_channels == 24
    assert fp_module.mlps.layer0.conv.out_channels == 16

    points = np.load('tests/data/sunrgbd/sunrgbd_trainval/lidar/000001.npy')

    # Target set: every other point. (B, N, 3) / (B, C1, N)
    xyz1 = torch.from_numpy(points[0::2, :3]).view(1, -1, 3).cuda()
    features1 = xyz1.repeat([1, 1, 4]).transpose(1, 2).contiguous().cuda()

    # Source set: every third point. (B, M, 3) / (B, C2, M)
    xyz2 = torch.from_numpy(points[1::3, :3]).view(1, -1, 3).cuda()
    features2 = xyz2.repeat([1, 1, 4]).transpose(1, 2).contiguous().cuda()

    fp_features = fp_module(xyz1, xyz2, features1, features2)
    assert fp_features.shape == torch.Size([1, 16, 50])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment